hanzo-rocm-kernels 0.10.2

ROCm/HIP kernels for Hanzo
// Qualifier shim. When __HIPCC__ is not defined, the HIP function qualifiers
// used in this file are defined as empty macros; when it is defined, the HIP
// runtime header is included instead.
// NOTE(review): this branch only blanks the qualifiers. The kernels below also
// reference blockIdx/blockDim/gridDim/threadIdx, which are not defined here.
#ifndef __HIPCC__
#define __device__
#define __global__
#define __forceinline__
#else
#include <hip/hip_runtime.h>
#endif

#include <stddef.h>
#include <stdint.h>
#include <math.h>

// Reports whether `strides` describes a densely packed layout for `dims`, with
// the last dimension varying fastest.
//
//   num_dims  number of entries in both `dims` and `strides`
//   dims      extent of each dimension
//   strides   stride of each dimension, compared against products of extents
//
// Walks the dimensions from last to first while tracking the stride a packed
// layout would have at that position. Dimensions whose extent is 0 or 1 are
// not compared, so their stride value is ignored. Returns true when
// `num_dims` is 0.
__device__ bool is_contiguous(
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    size_t expected_stride = 1;
    size_t axis = num_dims;
    while (axis-- > 0) {
        const size_t extent = dims[axis];
        const bool stride_matters = extent > 1;
        if (stride_matters && strides[axis] != expected_stride) {
            return false;
        }
        expected_stride *= extent;
    }
    return true;
}

// UNARY_OP(TYPENAME, FN_NAME, FUNC)
//
// Defines an extern "C" kernel named FN_NAME that evaluates the expression
// FUNC with `x` bound to one input element and stores the result in `out`.
//
//   numel             number of elements to produce
//   num_dims          number of dimensions described by dims_and_strides
//   dims_and_strides  num_dims extents followed by num_dims strides, the
//                     strides counted in elements; a null pointer selects the
//                     contiguous path without inspecting the layout
//   input             source elements; read at index i on the contiguous path
//                     and at the stride-derived offset otherwise
//   out               destination; always written at the linear index i
//
// Each thread handles the elements i, i + step, i + 2 * step, ... where step
// is the total number of launched threads (a grid-stride loop), so any 1D
// launch configuration covers all `numel` elements.
//
// All indices are size_t. `numel`, the extents and the strides are size_t, so
// narrower counters would wrap for tensors with 2^32 or more elements or for
// strided offsets past 2^32.
#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *dims_and_strides, \
    const TYPENAME *input, \
    TYPENAME *out \
) { \
    /* Widen before multiplying so the products are not computed in 32 bits. */ \
    const size_t first_i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const size_t step_i = (size_t)blockDim.x * gridDim.x; \
    const size_t *dims = dims_and_strides; \
    /* No pointer arithmetic on a null layout pointer. */ \
    const size_t *strides = dims_and_strides == nullptr ? nullptr : dims_and_strides + num_dims; \
    bool contiguous = dims_and_strides == nullptr || is_contiguous(num_dims, dims, strides); \
    if (contiguous) { \
        for (size_t i = first_i; i < numel; i += step_i) { \
            TYPENAME x = input[i]; \
            out[i] = FUNC; \
        } \
    } else { \
        for (size_t i = first_i; i < numel; i += step_i) { \
            /* Peel coordinates off the linear index, last dimension first. */ \
            size_t tmp_i = i; \
            size_t src_i = 0; \
            for (size_t d = num_dims; d-- > 0;) { \
                src_i += (tmp_i % dims[d]) * strides[d]; \
                tmp_i /= dims[d]; \
            } \
            TYPENAME x = input[src_i]; \
            out[i] = FUNC; \
        } \
    } \
}

// UNARY_OP_SCALAR(TYPENAME, FN_NAME, FUNC)
//
// Defines an extern "C" kernel named FN_NAME that evaluates the expression
// FUNC with `x` bound to one input element and `scalar` bound to the extra
// kernel argument, and stores the result in `out`.
//
//   numel             number of elements to produce
//   num_dims          number of dimensions described by dims_and_strides
//   dims_and_strides  num_dims extents followed by num_dims strides, the
//                     strides counted in elements; a null pointer selects the
//                     contiguous path without inspecting the layout
//   input             source elements; read at index i on the contiguous path
//                     and at the stride-derived offset otherwise
//   out               destination; always written at the linear index i
//   scalar            value made available to FUNC, the same for every element
//
// Each thread handles the elements i, i + step, i + 2 * step, ... where step
// is the total number of launched threads (a grid-stride loop).
//
// All indices are size_t. `numel`, the extents and the strides are size_t, so
// narrower counters would wrap for tensors with 2^32 or more elements or for
// strided offsets past 2^32.
#define UNARY_OP_SCALAR(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *dims_and_strides, \
    const TYPENAME *input, \
    TYPENAME *out, \
    TYPENAME scalar \
) { \
    /* Widen before multiplying so the products are not computed in 32 bits. */ \
    const size_t first_i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const size_t step_i = (size_t)blockDim.x * gridDim.x; \
    const size_t *dims = dims_and_strides; \
    /* No pointer arithmetic on a null layout pointer. */ \
    const size_t *strides = dims_and_strides == nullptr ? nullptr : dims_and_strides + num_dims; \
    bool contiguous = dims_and_strides == nullptr || is_contiguous(num_dims, dims, strides); \
    if (contiguous) { \
        for (size_t i = first_i; i < numel; i += step_i) { \
            TYPENAME x = input[i]; \
            out[i] = FUNC; \
        } \
    } else { \
        for (size_t i = first_i; i < numel; i += step_i) { \
            /* Peel coordinates off the linear index, last dimension first. */ \
            size_t tmp_i = i; \
            size_t src_i = 0; \
            for (size_t d = num_dims; d-- > 0;) { \
                src_i += (tmp_i % dims[d]) * strides[d]; \
                tmp_i /= dims[d]; \
            } \
            TYPENAME x = input[src_i]; \
            out[i] = FUNC; \
        } \
    } \
}

// Kernel instantiations for float, double and the integer element types.
//
// UNARY_FP and UNARY_FP_SCALAR stamp out the float and double flavours of one
// operation together: NAME##_f32 evaluates F32_EXPR and NAME##_f64 evaluates
// F64_EXPR. Integer kernels are instantiated directly with UNARY_OP.
#define UNARY_FP(NAME, F32_EXPR, F64_EXPR) \
UNARY_OP(float, NAME##_f32, F32_EXPR) \
UNARY_OP(double, NAME##_f64, F64_EXPR)

#define UNARY_FP_SCALAR(NAME, F32_EXPR, F64_EXPR) \
UNARY_OP_SCALAR(float, NAME##_f32, F32_EXPR) \
UNARY_OP_SCALAR(double, NAME##_f64, F64_EXPR)

// Element-wise copy.
UNARY_FP(ucopy, x, x)
UNARY_OP(uint8_t, ucopy_u8, x)
UNARY_OP(uint32_t, ucopy_u32, x)
UNARY_OP(int64_t, ucopy_i64, x)

// Activations and elementary functions.
UNARY_FP(urelu, (x > 0.0f ? x : 0.0f), (x > 0.0 ? x : 0.0))

UNARY_FP(usigmoid, (1.0f / (1.0f + expf(-x))), (1.0 / (1.0 + exp(-x))))

// NOTE(review): the utan_* kernels evaluate tanhf/tanh (hyperbolic tangent),
// not tan. Confirm the name matches what the host-side lookup expects.
UNARY_FP(utan, tanhf(x), tanh(x))

UNARY_FP(uexp, expf(x), exp(x))

UNARY_FP(ulog, logf(x), log(x))

UNARY_FP(usin, sinf(x), sin(x))

UNARY_FP(ucos, cosf(x), cos(x))

UNARY_FP(usqrt, sqrtf(x), sqrt(x))

// Absolute value and negation also have a 64-bit integer flavour.
UNARY_FP(uabs, fabsf(x), fabs(x))
UNARY_OP(int64_t, uabs_i64, (x < 0 ? -x : x))

UNARY_FP(uneg, -x, -x)
UNARY_OP(int64_t, uneg_i64, -x)

UNARY_FP(urecip, (1.0f / x), (1.0 / x))

// Rounding.
UNARY_FP(ufloor, floorf(x), floor(x))

UNARY_FP(uceil, ceilf(x), ceil(x))

UNARY_FP(uround, roundf(x), round(x))

UNARY_FP(usqr, (x * x), (x * x))

// Sign: 1 for positive, -1 for negative, 0 otherwise.
UNARY_FP(usign, (x > 0.0f ? 1.0f : (x < 0.0f ? -1.0f : 0.0f)), (x > 0.0 ? 1.0 : (x < 0.0 ? -1.0 : 0.0)))

// GELU, tanh-based form.
UNARY_FP(ugelu, (0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)))), (0.5 * x * (1.0 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))))

// GELU, erf-based form.
UNARY_FP(ugelu_erf, (0.5f * x * (1.0f + erff(x * 0.7071067811865475f))), (0.5 * x * (1.0 + erf(x * 0.7071067811865475))))

UNARY_FP(usilu, (x / (1.0f + expf(-x))), (x / (1.0 + exp(-x))))

UNARY_FP(uerf, erff(x), erf(x))

// Kernels that take an additional `scalar` argument.
UNARY_FP_SCALAR(uelu, (x > 0.0f ? x : scalar * (expf(x) - 1.0f)), (x > 0.0 ? x : scalar * (exp(x) - 1.0)))

UNARY_FP_SCALAR(upowf, powf(x, scalar), pow(x, scalar))

// 16-bit float variants: load/compute in float, cast result back.
#if defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>

// UNARY_OP_F(TYPENAME, FN_NAME, FUNC)
//
// Defines an extern "C" kernel named FN_NAME for an element type that is
// converted to float for the computation: each input element is cast to float
// and bound to `x`, FUNC is evaluated, and the result is cast back to TYPENAME
// before it is stored in `out`.
//
//   numel             number of elements to produce
//   num_dims          number of dimensions described by dims_and_strides
//   dims_and_strides  num_dims extents followed by num_dims strides, the
//                     strides counted in elements; a null pointer selects the
//                     contiguous path without inspecting the layout
//   input             source elements; read at index i on the contiguous path
//                     and at the stride-derived offset otherwise
//   out               destination; always written at the linear index i
//
// Each thread handles the elements i, i + step, i + 2 * step, ... where step
// is the total number of launched threads (a grid-stride loop).
//
// All indices are size_t. `numel`, the extents and the strides are size_t, so
// narrower counters would wrap for tensors with 2^32 or more elements or for
// strided offsets past 2^32.
#define UNARY_OP_F(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *dims_and_strides, \
    const TYPENAME *input, \
    TYPENAME *out \
) { \
    /* Widen before multiplying so the products are not computed in 32 bits. */ \
    const size_t first_i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const size_t step_i = (size_t)blockDim.x * gridDim.x; \
    const size_t *dims = dims_and_strides; \
    /* No pointer arithmetic on a null layout pointer. */ \
    const size_t *strides = dims_and_strides == nullptr ? nullptr : dims_and_strides + num_dims; \
    bool contiguous = dims_and_strides == nullptr || is_contiguous(num_dims, dims, strides); \
    if (contiguous) { \
        for (size_t i = first_i; i < numel; i += step_i) { \
            float x = (float)input[i]; \
            out[i] = (TYPENAME)(FUNC); \
        } \
    } else { \
        for (size_t i = first_i; i < numel; i += step_i) { \
            /* Peel coordinates off the linear index, last dimension first. */ \
            size_t tmp_i = i; \
            size_t src_i = 0; \
            for (size_t d = num_dims; d-- > 0;) { \
                src_i += (tmp_i % dims[d]) * strides[d]; \
                tmp_i /= dims[d]; \
            } \
            float x = (float)input[src_i]; \
            out[i] = (TYPENAME)(FUNC); \
        } \
    } \
}

// UNARY_OP_SCALAR_F(TYPENAME, FN_NAME, FUNC)
//
// Defines an extern "C" kernel named FN_NAME for an element type that is
// converted to float for the computation. Each input element is cast to float
// and bound to `x`, the extra kernel argument is cast to float and bound to
// `scalar`, FUNC is evaluated, and the result is cast back to TYPENAME before
// it is stored in `out`.
//
//   numel             number of elements to produce
//   num_dims          number of dimensions described by dims_and_strides
//   dims_and_strides  num_dims extents followed by num_dims strides, the
//                     strides counted in elements; a null pointer selects the
//                     contiguous path without inspecting the layout
//   input             source elements; read at index i on the contiguous path
//                     and at the stride-derived offset otherwise
//   out               destination; always written at the linear index i
//   scalar_in         extra operand in TYPENAME, converted to float once
//
// Each thread handles the elements i, i + step, i + 2 * step, ... where step
// is the total number of launched threads (a grid-stride loop).
//
// All indices are size_t. `numel`, the extents and the strides are size_t, so
// narrower counters would wrap for tensors with 2^32 or more elements or for
// strided offsets past 2^32.
#define UNARY_OP_SCALAR_F(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *dims_and_strides, \
    const TYPENAME *input, \
    TYPENAME *out, \
    TYPENAME scalar_in \
) { \
    /* Widen before multiplying so the products are not computed in 32 bits. */ \
    const size_t first_i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const size_t step_i = (size_t)blockDim.x * gridDim.x; \
    const size_t *dims = dims_and_strides; \
    /* No pointer arithmetic on a null layout pointer. */ \
    const size_t *strides = dims_and_strides == nullptr ? nullptr : dims_and_strides + num_dims; \
    float scalar = (float)scalar_in; \
    bool contiguous = dims_and_strides == nullptr || is_contiguous(num_dims, dims, strides); \
    if (contiguous) { \
        for (size_t i = first_i; i < numel; i += step_i) { \
            float x = (float)input[i]; \
            out[i] = (TYPENAME)(FUNC); \
        } \
    } else { \
        for (size_t i = first_i; i < numel; i += step_i) { \
            /* Peel coordinates off the linear index, last dimension first. */ \
            size_t tmp_i = i; \
            size_t src_i = 0; \
            for (size_t d = num_dims; d-- > 0;) { \
                src_i += (tmp_i % dims[d]) * strides[d]; \
                tmp_i /= dims[d]; \
            } \
            float x = (float)input[src_i]; \
            out[i] = (TYPENAME)(FUNC); \
        } \
    } \
}

// UNARY_ALL_F(ELEM_T, TAG)
//
// Emits the full set of unary kernels for one element type ELEM_T, naming each
// kernel <op>_TAG. Every expression is written in float arithmetic and refers
// to `x`; the two scalar kernels also refer to `scalar`.
// NOTE(review): utan_TAG evaluates tanhf (hyperbolic tangent), not tan.
// Confirm the name matches what the host-side lookup expects.
#define UNARY_ALL_F(ELEM_T, TAG) \
/* copy and activations */ \
UNARY_OP_F(ELEM_T, ucopy_##TAG, x) \
UNARY_OP_F(ELEM_T, urelu_##TAG, (x > 0.0f ? x : 0.0f)) \
UNARY_OP_F(ELEM_T, usigmoid_##TAG, (1.0f / (1.0f + expf(-x)))) \
UNARY_OP_F(ELEM_T, utan_##TAG, tanhf(x)) \
/* elementary functions */ \
UNARY_OP_F(ELEM_T, uexp_##TAG, expf(x)) \
UNARY_OP_F(ELEM_T, ulog_##TAG, logf(x)) \
UNARY_OP_F(ELEM_T, usin_##TAG, sinf(x)) \
UNARY_OP_F(ELEM_T, ucos_##TAG, cosf(x)) \
UNARY_OP_F(ELEM_T, usqrt_##TAG, sqrtf(x)) \
UNARY_OP_F(ELEM_T, uabs_##TAG, fabsf(x)) \
UNARY_OP_F(ELEM_T, uneg_##TAG, -x) \
UNARY_OP_F(ELEM_T, urecip_##TAG, (1.0f / x)) \
/* rounding */ \
UNARY_OP_F(ELEM_T, ufloor_##TAG, floorf(x)) \
UNARY_OP_F(ELEM_T, uceil_##TAG, ceilf(x)) \
UNARY_OP_F(ELEM_T, uround_##TAG, roundf(x)) \
UNARY_OP_F(ELEM_T, usqr_##TAG, (x * x)) \
UNARY_OP_F(ELEM_T, usign_##TAG, (x > 0.0f ? 1.0f : (x < 0.0f ? -1.0f : 0.0f))) \
/* GELU (tanh-based and erf-based), SiLU, erf */ \
UNARY_OP_F(ELEM_T, ugelu_##TAG, (0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x))))) \
UNARY_OP_F(ELEM_T, ugelu_erf_##TAG, (0.5f * x * (1.0f + erff(x * 0.7071067811865475f)))) \
UNARY_OP_F(ELEM_T, usilu_##TAG, (x / (1.0f + expf(-x)))) \
UNARY_OP_F(ELEM_T, uerf_##TAG, erff(x)) \
/* kernels taking an extra scalar operand */ \
UNARY_OP_SCALAR_F(ELEM_T, uelu_##TAG, (x > 0.0f ? x : scalar * (expf(x) - 1.0f))) \
UNARY_OP_SCALAR_F(ELEM_T, upowf_##TAG, powf(x, scalar))

// __half kernels, suffixed _f16.
UNARY_ALL_F(__half, f16)
// hip_bfloat16 kernels, suffixed _bf16.
UNARY_ALL_F(hip_bfloat16, bf16)
#endif