hanzo-rocm-kernels 0.10.2

ROCm/HIP kernels for Hanzo
// Convolution kernels for ROCm

// im2col for 1D convolution.
//
// Expected launch: a 1D grid with at least `threads` threads. Each thread
// handles one (batch, output position, input channel) triple and writes the
// l_k consecutive taps for it. Excess threads exit immediately.
//
//   dims    : source dims, read as (b, c_in, l_in)
//   strides : source element strides for those three dims
//   src     : strided input, addressed via `strides`
//   dst     : contiguous output laid out as (b, l_out, c_in, l_k)
//   threads : number of work items; presumably b * l_out * c_in — confirm at
//             the launch site
//
// Taps that land in the zero-padding region are written as 0.
template <typename T>
__device__ void im2col1d(
    size_t threads,
    size_t l_out,
    size_t l_k,
    size_t stride,
    size_t padding,
    size_t dilation,
    const size_t *dims,
    const size_t *strides,
    const T *src,
    T *dst
) {
    // blockIdx.x / blockDim.x are 32-bit; widen before multiplying so the flat
    // index is computed in size_t rather than wrapping in 32-bit arithmetic.
    const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (tid >= threads) return;

    const size_t b = dims[0];
    const size_t c_in = dims[1];
    const size_t l_in = dims[2];

    // tid == ((b_idx * l_out) + l_idx) * c_in + out_ch
    const size_t out_ch = tid % c_in;
    const size_t l_idx = (tid / c_in) % l_out;
    const size_t b_idx = tid / (c_in * l_out);
    // Defensive: never write past the last batch even if `threads` is too large.
    if (b_idx >= b) return;

    // Loop-invariant bases: the source row for (b_idx, out_ch) and the
    // l_k-wide destination row for (b_idx, l_idx, out_ch).
    const T *src_row = src + b_idx * strides[0] + out_ch * strides[1];
    T *dst_row = dst + ((b_idx * l_out + l_idx) * c_in + out_ch) * l_k;
    const size_t src_s2 = strides[2];

    for (size_t k_idx = 0; k_idx < l_k; ++k_idx) {
        // Position in the virtually padded input; real samples occupy
        // [padding, padding + l_in).
        const size_t l_in_idx = l_idx * stride + k_idx * dilation;
        if (l_in_idx >= padding && l_in_idx < l_in + padding) {
            dst_row[k_idx] = src_row[(l_in_idx - padding) * src_s2];
        } else {
            dst_row[k_idx] = 0;
        }
    }
}

// Emits an extern "C" kernel named im2col1d_<SUFFIX> that forwards every
// argument unchanged to the im2col1d<ELEM> device template.
#define IM2COL1D(ELEM, SUFFIX) \
extern "C" __global__ void im2col1d_##SUFFIX( \
    size_t threads, size_t l_out, size_t l_k, \
    size_t stride, size_t padding, size_t dilation, \
    const size_t *dims, const size_t *strides, \
    const ELEM *src, ELEM *dst) \
{ \
    im2col1d<ELEM>(threads, l_out, l_k, stride, padding, dilation, \
                   dims, strides, src, dst); \
}

// Exported element types: f32, f64, f16.
IM2COL1D(float, f32)
IM2COL1D(double, f64)
IM2COL1D(half, f16)

// Transposed 1D convolution: one thread computes one output element.
//
// Expected launch: a 1D grid covering b * c_out * l_out threads; excess
// threads exit immediately.
//
// `info` packs four 3-element arrays:
//   info[0..3)   src dims     (b, c_in, l_in)
//   info[3..6)   src element strides
//   info[6..9)   kernel dims  (k_dims[1] = c_out, k_dims[2] = l_k;
//                k_dims[0] is not read here)
//   info[9..12)  kernel element strides; the kernel is indexed as
//                in_ch * k_s[0] + out_ch * k_s[1] + tap * k_s[2]
// `dst` is written contiguously as (b, c_out, l_out).
// T is the storage type; A is the accumulator type (e.g. float for half).
//
// src_numel and out_padding are not read here. Presumably l_out already
// accounts for out_padding — confirm at the launch site. Assumes stride > 0.
template <typename T, typename A>
__device__ void conv_transpose1d(
    const size_t src_numel,
    const size_t l_out,
    const size_t stride,
    const size_t padding,
    const size_t out_padding,
    const size_t dilation,
    const size_t *info,
    const T *src,
    const T *kernel,
    T *dst
) {
    // blockIdx.x / blockDim.x are 32-bit; widen before multiplying.
    const size_t dst_i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t *src_dims = info;
    const size_t *src_s = info + 3;
    const size_t *k_dims = info + 6;
    const size_t *k_s = info + 9;
    const size_t l_k = k_dims[2];
    const size_t c_out = k_dims[1];
    const size_t c_in = src_dims[1];
    const size_t l_in = src_dims[2];
    if (dst_i >= src_dims[0] * c_out * l_out) {
        return;
    }

    const size_t b_idx = dst_i / (l_out * c_out);
    const size_t dst_c_idx = (dst_i / l_out) % c_out;
    const size_t out_x = dst_i % l_out;

    const size_t src_idx0 = b_idx * src_s[0];
    const size_t k_idx0 = dst_c_idx * k_s[1];
    // out_x = inp_x * stride + k_x * dilation - padding, solved for inp_x:
    // inp_x = (out_x + padding - k_x * dilation) / stride, when exact and >= 0.
    const size_t shifted = out_x + padding;

    A d = 0;
    for (size_t k_x = 0; k_x < l_k; ++k_x) {
        const size_t tap = k_x * dilation;
        // tap never decreases with k_x, so once it exceeds `shifted` every later
        // tap would need a negative input position too.
        if (tap > shifted) {
            break;
        }
        const size_t inp_x_stride = shifted - tap;
        if (inp_x_stride % stride != 0) {
            continue;
        }
        const size_t inp_x = inp_x_stride / stride;
        if (inp_x >= l_in) {
            continue;
        }
        const size_t src_tap = src_idx0 + inp_x * src_s[2];
        const size_t k_tap = k_idx0 + k_x * k_s[2];
        for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
            d += static_cast<A>(src[src_tap + src_c_idx * src_s[1]]) *
                 static_cast<A>(kernel[k_tap + src_c_idx * k_s[0]]);
        }
    }
    dst[dst_i] = static_cast<T>(d);
}

// Emits an extern "C" kernel named conv_transpose1d_<SUFFIX> that forwards
// every argument unchanged to conv_transpose1d<ELEM, ACC>, where ELEM is the
// storage type and ACC the accumulator type.
#define CONV_TRANSPOSE1D(ELEM, ACC, SUFFIX) \
extern "C" __global__ void conv_transpose1d_##SUFFIX( \
    const size_t src_numel, const size_t l_out, \
    const size_t stride, const size_t padding, \
    const size_t out_padding, const size_t dilation, \
    const size_t *info, \
    const ELEM *src, const ELEM *kernel, ELEM *dst) \
{ \
    conv_transpose1d<ELEM, ACC>(src_numel, l_out, stride, padding, \
                                out_padding, dilation, info, src, kernel, dst); \
}

// Exported element types; f16 accumulates in float.
CONV_TRANSPOSE1D(float, float, f32)
CONV_TRANSPOSE1D(double, double, f64)
CONV_TRANSPOSE1D(half, float, f16)

#include <hip/hip_bfloat16.h>
// Widens a hip_bfloat16 value to float.
__device__ __forceinline__ float bfloat162float(hip_bfloat16 v) {
    return static_cast<float>(v);
}
// Narrows a float to hip_bfloat16 via hip_bfloat16's float constructor.
__device__ __forceinline__ hip_bfloat16 float2bfloat16(float v) {
    return static_cast<hip_bfloat16>(v);
}

// bfloat16 specialization of the transposed 1D convolution. It reads
// hip_bfloat16 inputs, accumulates in float, and rounds back through
// bfloat162float / float2bfloat16. One thread computes one output element.
//
// Expected launch: a 1D grid covering b * c_out * l_out threads; excess
// threads exit immediately.
//
// `info` packs four 3-element arrays:
//   info[0..3)   src dims     (b, c_in, l_in)
//   info[3..6)   src element strides
//   info[6..9)   kernel dims  (k_dims[1] = c_out, k_dims[2] = l_k;
//                k_dims[0] is not read here)
//   info[9..12)  kernel element strides; the kernel is indexed as
//                in_ch * k_s[0] + out_ch * k_s[1] + tap * k_s[2]
// `dst` is written contiguously as (b, c_out, l_out).
//
// src_numel and out_padding are not read here. Presumably l_out already
// accounts for out_padding — confirm at the launch site. Assumes stride > 0.
template <>
__device__ void conv_transpose1d<hip_bfloat16, float>(
    const size_t src_numel,
    const size_t l_out,
    const size_t stride,
    const size_t padding,
    const size_t out_padding,
    const size_t dilation,
    const size_t *info,
    const hip_bfloat16 *src,
    const hip_bfloat16 *kernel,
    hip_bfloat16 *dst
) {
    // blockIdx.x / blockDim.x are 32-bit; widen before multiplying.
    const size_t dst_i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t *src_dims = info;
    const size_t *src_s = info + 3;
    const size_t *k_dims = info + 6;
    const size_t *k_s = info + 9;
    const size_t l_k = k_dims[2];
    const size_t c_out = k_dims[1];
    const size_t c_in = src_dims[1];
    const size_t l_in = src_dims[2];
    if (dst_i >= src_dims[0] * c_out * l_out) {
        return;
    }

    const size_t b_idx = dst_i / (l_out * c_out);
    const size_t dst_c_idx = (dst_i / l_out) % c_out;
    const size_t out_x = dst_i % l_out;

    const size_t src_idx0 = b_idx * src_s[0];
    const size_t k_idx0 = dst_c_idx * k_s[1];
    // out_x = inp_x * stride + k_x * dilation - padding, solved for inp_x:
    // inp_x = (out_x + padding - k_x * dilation) / stride, when exact and >= 0.
    const size_t shifted = out_x + padding;

    float d = 0;
    for (size_t k_x = 0; k_x < l_k; ++k_x) {
        const size_t tap = k_x * dilation;
        // tap never decreases with k_x, so once it exceeds `shifted` every later
        // tap would need a negative input position too.
        if (tap > shifted) {
            break;
        }
        const size_t inp_x_stride = shifted - tap;
        if (inp_x_stride % stride != 0) {
            continue;
        }
        const size_t inp_x = inp_x_stride / stride;
        if (inp_x >= l_in) {
            continue;
        }
        const size_t src_tap = src_idx0 + inp_x * src_s[2];
        const size_t k_tap = k_idx0 + k_x * k_s[2];
        for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
            d += bfloat162float(src[src_tap + src_c_idx * src_s[1]]) *
                 bfloat162float(kernel[k_tap + src_c_idx * k_s[0]]);
        }
    }
    dst[dst_i] = float2bfloat16(d);
}

// extern "C" entry point for bfloat16 transposed 1D convolution. It forwards
// every argument unchanged to the conv_transpose1d<hip_bfloat16, float>
// specialization, which accumulates in float.
extern "C" __global__ void conv_transpose1d_bf16(
    const size_t src_numel, const size_t l_out,
    const size_t stride, const size_t padding,
    const size_t out_padding, const size_t dilation,
    const size_t *info,
    const hip_bfloat16 *src, const hip_bfloat16 *kernel, hip_bfloat16 *dst
) {
    conv_transpose1d<hip_bfloat16, float>(
        src_numel, l_out, stride, padding, out_padding, dilation,
        info, src, kernel, dst);
}