// hanzo-rocm-kernels 0.10.2

// ROCm/HIP kernels for Hanzo.
#ifndef __HIPCC__
#define __device__
#define __global__
#define __forceinline__
#else
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#endif

#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Returns true when `strides` describe a dense row-major layout for `dims`.
// The check walks from the innermost dimension outward, tracking the stride a
// dense layout would have at each position. Dimensions of size 0 or 1 never
// move the read position, so their stride is not checked.
__device__ bool is_contiguous(
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    size_t expected_stride = 1;
    for (size_t k = num_dims; k > 0; --k) {
        const size_t extent = dims[k - 1];
        const bool stride_matters = extent > 1;
        if (stride_matters && strides[k - 1] != expected_stride) {
            return false;
        }
        expected_stride *= extent;
    }
    return true;
}

// Converts a flat row-major index `idx` over `dims` into an element offset
// using `strides`. It peels off one coordinate per dimension, innermost first.
// The offset is accumulated in unsigned int, so it wraps modulo 2^32.
__device__ unsigned int get_strided_index(
    unsigned int idx,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int offset = 0;
    unsigned int remaining = idx;
    for (size_t k = num_dims; k-- > 0;) {
        const size_t extent = dims[k];
        const size_t coord = remaining % extent;
        offset += (unsigned int)(coord * strides[k]);
        remaining = (unsigned int)(remaining / extent);
    }
    return offset;
}

// Minimal bfloat16 storage type. It holds the 16 high-order bits of an
// IEEE-754 binary32 value, and values are used by converting to/from float.
struct __rocm_bf16 {
    uint16_t __x;  // raw bf16 bit pattern (sign, 8 exponent bits, 7 mantissa bits)

    __device__ __rocm_bf16() : __x(0) {}

    // Rounds to the nearest bf16 value, ties to even. NaN inputs stay NaN.
    __device__ __rocm_bf16(float f) : __x(from_float(f)) {}

    // Narrows to float first, then rounds as the float constructor does.
    // The value may therefore be rounded twice.
    __device__ __rocm_bf16(double d) : __x(from_float((float)d)) {}

    // Exact widening: the bf16 bits become the upper half of a float.
    __device__ operator float() const {
        unsigned int u = (unsigned int)__x << 16;
        float f;
        memcpy(&f, &u, sizeof(f));
        return f;
    }

    // float -> bf16 bit pattern, round-to-nearest-even.
    // Plain truncation (u >> 16) rounds toward zero. It also maps a NaN whose
    // payload lives only in the low 16 bits to +/-Inf.
    static __device__ uint16_t from_float(float f) {
        unsigned int u;
        memcpy(&u, &f, sizeof(u));
        if ((u & 0x7fffffffu) > 0x7f800000u) {
            // NaN: keep the surviving payload bits. If none survive, set the
            // top (quiet) mantissa bit so the result is not Inf.
            const uint16_t hi = (uint16_t)(u >> 16);
            return (hi & 0x007fu) ? hi : (uint16_t)(hi | 0x0040u);
        }
        // Add just under half an ulp, plus 1 when the kept LSB is odd (ties to
        // even). Finite values past the bf16 range round up to +/-Inf.
        u += 0x7fffu + ((u >> 16) & 1u);
        return (uint16_t)(u >> 16);
    }
};

// CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) defines an extern "C" kernel
// FN_NAME that writes out[i] = (DST_TYPENAME)<source element> for every i in
// [0, numel).
//
//   numel    - number of elements written to `out`, densely at index i.
//   num_dims - length of each of the dims / strides arrays packed in `info`.
//   info     - nullptr, or num_dims dims followed by num_dims strides (in
//              elements).
//   inp/out  - device buffers of SRC_TYPENAME / DST_TYPENAME.
//
// How the source element is read:
//   - If info is nullptr or the layout is contiguous, the kernel reads inp[i].
//   - Otherwise it reads inp at get_strided_index(i, ...).
//
// Launch and indexing notes:
//   - Each thread starts at its global 1-D thread id and advances by the total
//     number of launched threads, so any 1-D grid covers every element.
//   - Loop arithmetic is size_t. With a 32-bit counter, `i += step` can wrap
//     past 2^32, and the loop may never terminate when numel is near or above
//     2^32.
//   - `strides` is formed only when info is non-null, because info + num_dims
//     on a null pointer is undefined behaviour.
//
// NOTE(review): get_strided_index takes and returns unsigned int, so the
// strided path still only supports flat and strided indices below 2^32.
#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const SRC_TYPENAME *inp, \
    DST_TYPENAME *out \
) { \
    const size_t start = (size_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const size_t step = (size_t)blockDim.x * gridDim.x; \
    if (info == nullptr || is_contiguous(num_dims, info, info + num_dims)) { \
        for (size_t i = start; i < numel; i += step) { \
            out[i] = (DST_TYPENAME)inp[i]; \
        } \
    } else { \
        const size_t *dims = info; \
        const size_t *strides = info + num_dims; \
        for (size_t i = start; i < numel; i += step) { \
            const size_t src_i = get_strided_index((unsigned int)i, num_dims, dims, strides); \
            out[i] = (DST_TYPENAME)inp[src_i]; \
        } \
    } \
}

// CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) defines an extern "C"
// kernel FN_NAME that writes out[i] = (DST_TYPENAME)(float)<source element>
// for every i in [0, numel).
//
// Every element is routed through float. The 16-bit types (half, __rocm_bf16)
// therefore only need conversions to and from float.
//
//   numel    - number of elements written to `out`, densely at index i.
//   num_dims - length of each of the dims / strides arrays packed in `info`.
//   info     - nullptr, or num_dims dims followed by num_dims strides (in
//              elements).
//   inp/out  - device buffers of SRC_TYPENAME / DST_TYPENAME.
//
// How the source element is read:
//   - If info is nullptr or the layout is contiguous, the kernel reads inp[i].
//   - Otherwise it reads inp at get_strided_index(i, ...).
//
// Launch and indexing notes:
//   - Each thread starts at its global 1-D thread id and advances by the total
//     number of launched threads, so any 1-D grid covers every element.
//   - Loop arithmetic is size_t. With a 32-bit counter, `i += step` can wrap
//     past 2^32, and the loop may never terminate when numel is near or above
//     2^32.
//   - `strides` is formed only when info is non-null, because info + num_dims
//     on a null pointer is undefined behaviour.
//
// NOTE(review): get_strided_index takes and returns unsigned int, so the
// strided path still only supports flat and strided indices below 2^32.
#define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const SRC_TYPENAME *inp, \
    DST_TYPENAME *out \
) { \
    const size_t start = (size_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const size_t step = (size_t)blockDim.x * gridDim.x; \
    if (info == nullptr || is_contiguous(num_dims, info, info + num_dims)) { \
        for (size_t i = start; i < numel; i += step) { \
            out[i] = (DST_TYPENAME)(float)inp[i]; \
        } \
    } else { \
        const size_t *dims = info; \
        const size_t *strides = info + num_dims; \
        for (size_t i = start; i < numel; i += step) { \
            const size_t src_i = get_strided_index((unsigned int)i, num_dims, dims, strides); \
            out[i] = (DST_TYPENAME)(float)inp[src_i]; \
        } \
    } \
}

// Cast kernels for every (source, destination) pair over
// {u8, u32, i64, f32, f64, bf16, f16}, grouped by destination type.
// Pairs involving half or __rocm_bf16 go through float (CAST_THROUGH_OP).
// All other pairs use a direct C-style cast (CAST_OP).

// -> uint8_t
CAST_OP(uint8_t, uint8_t, cast_u8_u8)
CAST_OP(uint32_t, uint8_t, cast_u32_u8)
CAST_OP(int64_t, uint8_t, cast_i64_u8)
CAST_OP(float, uint8_t, cast_f32_u8)
CAST_OP(double, uint8_t, cast_f64_u8)
CAST_THROUGH_OP(__rocm_bf16, uint8_t, cast_bf16_u8)
CAST_THROUGH_OP(half, uint8_t, cast_f16_u8)

// -> uint32_t
CAST_OP(uint8_t, uint32_t, cast_u8_u32)
CAST_OP(uint32_t, uint32_t, cast_u32_u32)
CAST_OP(int64_t, uint32_t, cast_i64_u32)
CAST_OP(float, uint32_t, cast_f32_u32)
CAST_OP(double, uint32_t, cast_f64_u32)
CAST_THROUGH_OP(__rocm_bf16, uint32_t, cast_bf16_u32)
CAST_THROUGH_OP(half, uint32_t, cast_f16_u32)

// -> int64_t
CAST_OP(uint8_t, int64_t, cast_u8_i64)
CAST_OP(uint32_t, int64_t, cast_u32_i64)
CAST_OP(int64_t, int64_t, cast_i64_i64)
CAST_OP(float, int64_t, cast_f32_i64)
CAST_OP(double, int64_t, cast_f64_i64)
CAST_THROUGH_OP(__rocm_bf16, int64_t, cast_bf16_i64)
CAST_THROUGH_OP(half, int64_t, cast_f16_i64)

// -> float
CAST_OP(uint8_t, float, cast_u8_f32)
CAST_OP(uint32_t, float, cast_u32_f32)
CAST_OP(int64_t, float, cast_i64_f32)
CAST_OP(float, float, cast_f32_f32)
CAST_OP(double, float, cast_f64_f32)
CAST_THROUGH_OP(__rocm_bf16, float, cast_bf16_f32)
CAST_THROUGH_OP(half, float, cast_f16_f32)

// -> double
CAST_OP(uint8_t, double, cast_u8_f64)
CAST_OP(uint32_t, double, cast_u32_f64)
CAST_OP(int64_t, double, cast_i64_f64)
CAST_OP(float, double, cast_f32_f64)
CAST_OP(double, double, cast_f64_f64)
CAST_THROUGH_OP(__rocm_bf16, double, cast_bf16_f64)
CAST_THROUGH_OP(half, double, cast_f16_f64)

// -> __rocm_bf16
CAST_THROUGH_OP(uint8_t, __rocm_bf16, cast_u8_bf16)
CAST_THROUGH_OP(uint32_t, __rocm_bf16, cast_u32_bf16)
CAST_THROUGH_OP(int64_t, __rocm_bf16, cast_i64_bf16)
CAST_THROUGH_OP(float, __rocm_bf16, cast_f32_bf16)
CAST_THROUGH_OP(double, __rocm_bf16, cast_f64_bf16)
CAST_THROUGH_OP(__rocm_bf16, __rocm_bf16, cast_bf16_bf16)
CAST_THROUGH_OP(half, __rocm_bf16, cast_f16_bf16)

// -> half
CAST_THROUGH_OP(uint8_t, half, cast_u8_f16)
CAST_THROUGH_OP(uint32_t, half, cast_u32_f16)
CAST_THROUGH_OP(int64_t, half, cast_i64_f16)
CAST_THROUGH_OP(float, half, cast_f32_f16)
CAST_THROUGH_OP(double, half, cast_f64_f16)
CAST_THROUGH_OP(__rocm_bf16, half, cast_bf16_f16)
CAST_THROUGH_OP(half, half, cast_f16_f16)