hanzo-rocm-kernels 0.10.2

ROCm/HIP kernels for Hanzo
// When the translation unit is not being compiled by hipcc (__HIPCC__ is
// undefined) the HIP function qualifiers are defined as empty macros, so the
// qualified declarations below expand to ordinary C++ declarations.
// NOTE(review): presumably this exists so host-only tooling can parse the
// helpers; the kernel macros still reference HIP built-ins (blockIdx, ...)
// that this fallback does not provide.
#ifndef __HIPCC__
#define __device__
#define __global__
#define __forceinline__
#else
// Building with hipcc: pull in the HIP runtime declarations.
#include <hip/hip_runtime.h>
#endif

#include <stddef.h>
#include <stdint.h>

// Reports whether `dims`/`strides` describe a densely packed layout in which
// the last dimension varies fastest.
//
// The dimensions are visited from the last one to the first while a running
// product of the extents seen so far is maintained; every dimension whose
// extent is greater than 1 must have a stride equal to that product.
// Dimensions of extent 0 or 1 never have their stride inspected.
//
//   num_dims  number of entries in `dims` and `strides`
//   dims      extent of each dimension
//   strides   stride of each dimension, in the same unit as the extents
//
// Returns true when every inspected stride matches (including num_dims == 0),
// false at the first mismatch.
__device__ bool is_contiguous(
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    size_t expected_stride = 1;
    unsigned int visited = 0;
    while (visited < num_dims) {
        // Walk backwards: the innermost dimension is checked first.
        const unsigned int axis = num_dims - 1 - visited;
        const size_t extent = dims[axis];
        const bool stride_matters = extent > 1;
        if (stride_matters && strides[axis] != expected_stride) {
            return false;
        }
        expected_stride *= extent;
        ++visited;
    }
    return true;
}

// Converts a linear element index into an offset for a strided layout.
//
// `idx` is decomposed into one coordinate per dimension, starting with the
// last dimension (coordinate = remainder by that dimension's extent, then the
// index is divided by the extent), and each coordinate is multiplied by the
// matching stride and summed.
//
//   idx       linear index to translate
//   num_dims  number of entries in `dims` and `strides`
//   dims      extent of each dimension (an extent of 0 divides by zero)
//   strides   stride of each dimension
//
// Returns the accumulated offset. The accumulator and the return type are
// `unsigned int`, so the result is reduced modulo 2^32.
__device__ unsigned int get_strided_index(
    unsigned int idx,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int offset = 0;
    unsigned int remaining = idx;
    int axis = num_dims - 1;
    while (axis >= 0) {
        const size_t extent = dims[axis];
        // Coordinate of `idx` along this dimension.
        const unsigned int coord = remaining % extent;
        offset += coord * strides[axis];
        remaining /= extent;
        --axis;
    }
    return offset;
}

// FILL_OP(TYPENAME, FN_NAME) defines an extern "C" kernel named FN_NAME that
// stores `value` into every element of `buf[0 .. numel)`.
//
// Kernel parameters:
//   buf    destination buffer holding at least `numel` elements of TYPENAME
//   value  the value written to each element
//   numel  number of elements to write
//
// Launch layout: only the .x components of blockIdx/blockDim/gridDim/threadIdx
// are read, so a 1-D launch is expected. Each thread starts at its global
// index and advances by the total number of launched threads, so any grid size
// covers the whole buffer. No shared memory is used.
//
// The index and the stride are computed in size_t. With a 32-bit counter the
// loop could never terminate once numel exceeded UINT_MAX (the counter wraps
// and stays below numel), and blockIdx.x * blockDim.x could overflow.
#define FILL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    TYPENAME *buf, \
    TYPENAME value, \
    const size_t numel \
) { \
    const size_t stride = (size_t)blockDim.x * gridDim.x; \
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += stride) { \
        buf[i] = value; \
    } \
}

// CONST_SET_OP(TYPENAME, FN_NAME) defines an extern "C" kernel named FN_NAME
// that stores the constant `inp` into `numel` elements of `out`, honouring an
// optional strided layout.
//
// Kernel parameters:
//   numel     number of logical elements to write
//   num_dims  number of dimensions described by `info`
//   info      either nullptr, or 2 * num_dims values: the num_dims extents
//             followed by the num_dims strides
//   inp       the value written to each element
//   out       destination buffer
//
// When `info` is nullptr or is_contiguous() accepts the layout, elements
// out[0 .. numel) are written directly. Otherwise each linear index is
// decomposed into per-dimension coordinates (last dimension first) and the
// write goes to the sum of coordinate * stride.
//
// Launch layout: only the .x components of the launch built-ins are read, so a
// 1-D launch is expected; each thread starts at its global index and advances
// by the total number of launched threads. No shared memory is used.
//
// All index arithmetic is done in size_t. A 32-bit loop counter could never
// reach a numel above UINT_MAX (it wraps and the loop spins forever), and the
// strided offset is computed inline because get_strided_index() takes and
// returns unsigned int, which would truncate offsets of 2^32 or more.
#define CONST_SET_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const TYPENAME inp, \
    TYPENAME *out \
) { \
    const size_t first = (size_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const size_t step = (size_t)blockDim.x * gridDim.x; \
    if (info == nullptr || is_contiguous(num_dims, info, info + num_dims)) { \
        for (size_t i = first; i < numel; i += step) { \
            out[i] = inp; \
        } \
    } else { \
        const size_t *strides = info + num_dims; \
        for (size_t i = first; i < numel; i += step) { \
            /* Peel coordinates off the linear index, innermost dimension first. */ \
            size_t rem = i; \
            size_t dst_i = 0; \
            for (size_t d = num_dims; d-- > 0;) { \
                dst_i += (rem % info[d]) * strides[d]; \
                rem /= info[d]; \
            } \
            out[dst_i] = inp; \
        } \
    } \
}

// Fill kernels: one extern "C" entry point per element type. The second macro
// argument is the exported kernel name.
FILL_OP(float, fill_f32)
FILL_OP(double, fill_f64)
FILL_OP(uint8_t, fill_u8)
FILL_OP(uint32_t, fill_u32)
FILL_OP(int64_t, fill_i64)

// Layout-aware constant-set kernels for the same element types.
CONST_SET_OP(float, const_set_f32)
CONST_SET_OP(double, const_set_f64)
CONST_SET_OP(uint8_t, const_set_u8)
CONST_SET_OP(uint32_t, const_set_u32)
CONST_SET_OP(int64_t, const_set_i64)

// 16-bit float variants (plain assignment, no arithmetic).
// These are only compiled when __HIPCC__ is defined: the element types used
// below (__half, hip_bfloat16) are expected to come from the two HIP headers
// included here, which are not available in the non-HIP fallback build.
#if defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
// Half-precision and bfloat16 fill kernels.
FILL_OP(__half, fill_f16)
FILL_OP(hip_bfloat16, fill_bf16)
// Half-precision and bfloat16 layout-aware constant-set kernels.
CONST_SET_OP(__half, const_set_f16)
CONST_SET_OP(hip_bfloat16, const_set_bf16)
#endif