hodu_cuda_kernels 0.2.4

hodu cuda kernels
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <stdint.h>

// Math Constants

#if !defined(M_PI)
// pi as a double-precision literal. Single-precision code should cast it to float
// explicitly to avoid silent promotion to double arithmetic.
#define M_PI 3.14159265358979323846
#endif

#if !defined(INFINITY)
// Positive infinity, built from its IEEE-754 binary32 bit pattern. __int_as_float is a
// CUDA device intrinsic, so this fallback is only usable in device code.
#define INFINITY __int_as_float(0x7f800000)
#endif

#if !defined(NAN)
// A quiet NaN, built from its IEEE-754 binary32 bit pattern (device code only, as above).
#define NAN __int_as_float(0x7fc00000)
#endif

// Type Conversion Utilities

// Widens a value of any supported element type to float.
// The primary template relies on static_cast; the reduced-precision CUDA types are
// specialized below so each uses the conversion its header provides.
template <typename T> __device__ __forceinline__ float to_float(T val) {
    const float widened = static_cast<float>(val);
    return widened;
}

// FP8 (E4M3): use the type's float conversion operator.
template <> __device__ __forceinline__ float to_float<__nv_fp8_e4m3>(__nv_fp8_e4m3 val) {
    return static_cast<float>(val);
}

// FP8 (E5M2): use the type's float conversion operator.
template <> __device__ __forceinline__ float to_float<__nv_fp8_e5m2>(__nv_fp8_e5m2 val) {
    return static_cast<float>(val);
}

// bfloat16: use the CUDA conversion intrinsic.
template <> __device__ __forceinline__ float to_float<__nv_bfloat16>(__nv_bfloat16 val) {
    return __bfloat162float(val);
}

// half: use the CUDA conversion intrinsic.
template <> __device__ __forceinline__ float to_float<__half>(__half val) {
    return __half2float(val);
}

// Narrows a float to the element type T.
// The primary template relies on static_cast; the reduced-precision CUDA types are
// specialized below and go through each type's float constructor.
template <typename T> __device__ __forceinline__ T from_float(float val) {
    const T narrowed = static_cast<T>(val);
    return narrowed;
}

// FP8 (E4M3): construct from float.
template <> __device__ __forceinline__ __nv_fp8_e4m3 from_float<__nv_fp8_e4m3>(float val) {
    return static_cast<__nv_fp8_e4m3>(val);
}

// FP8 (E5M2): construct from float.
template <> __device__ __forceinline__ __nv_fp8_e5m2 from_float<__nv_fp8_e5m2>(float val) {
    return static_cast<__nv_fp8_e5m2>(val);
}

// bfloat16: construct from float.
template <> __device__ __forceinline__ __nv_bfloat16 from_float<__nv_bfloat16>(float val) {
    return static_cast<__nv_bfloat16>(val);
}

// half: construct from float.
template <> __device__ __forceinline__ __half from_float<__half>(float val) {
    return static_cast<__half>(val);
}

// Comparison Utilities

// Returns the larger of x and y. When x > y is false (a tie, or an unordered
// comparison involving NaN) the second argument y is returned.
template <typename T> __device__ __forceinline__ T maximum(T x, T y) {
    if (x > y)
        return x;
    return y;
}

// Reduced-precision floating types: compare after widening both operands to float,
// but return the original (unconverted) argument.
template <>
__device__ __forceinline__ __nv_fp8_e4m3 maximum<__nv_fp8_e4m3>(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y) {
    const float fx = to_float(x);
    const float fy = to_float(y);
    if (fx > fy)
        return x;
    return y;
}

template <>
__device__ __forceinline__ __nv_fp8_e5m2 maximum<__nv_fp8_e5m2>(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y) {
    const float fx = to_float(x);
    const float fy = to_float(y);
    if (fx > fy)
        return x;
    return y;
}

template <> __device__ __forceinline__ __half maximum<__half>(__half x, __half y) {
    const float fx = to_float(x);
    const float fy = to_float(y);
    if (fx > fy)
        return x;
    return y;
}

template <>
__device__ __forceinline__ __nv_bfloat16 maximum<__nv_bfloat16>(__nv_bfloat16 x, __nv_bfloat16 y) {
    const float fx = to_float(x);
    const float fy = to_float(y);
    if (fx > fy)
        return x;
    return y;
}

// Returns the smaller of x and y. When x < y is false (a tie, or an unordered
// comparison involving NaN) the second argument y is returned.
template <typename T> __device__ __forceinline__ T minimum(T x, T y) {
    if (x < y)
        return x;
    return y;
}

// Reduced-precision floating types: compare after widening both operands to float,
// but return the original (unconverted) argument.
template <>
__device__ __forceinline__ __nv_fp8_e4m3 minimum<__nv_fp8_e4m3>(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y) {
    const float fx = to_float(x);
    const float fy = to_float(y);
    if (fx < fy)
        return x;
    return y;
}

template <>
__device__ __forceinline__ __nv_fp8_e5m2 minimum<__nv_fp8_e5m2>(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y) {
    const float fx = to_float(x);
    const float fy = to_float(y);
    if (fx < fy)
        return x;
    return y;
}

template <> __device__ __forceinline__ __half minimum<__half>(__half x, __half y) {
    const float fx = to_float(x);
    const float fy = to_float(y);
    if (fx < fy)
        return x;
    return y;
}

template <>
__device__ __forceinline__ __nv_bfloat16 minimum<__nv_bfloat16>(__nv_bfloat16 x, __nv_bfloat16 y) {
    const float fx = to_float(x);
    const float fy = to_float(y);
    if (fx < fy)
        return x;
    return y;
}

// Power and Exponential Functions

// Integer power base**exp by binary exponentiation (square-and-multiply).
// Each set bit of exp, from least significant upward, multiplies the accumulator by
// the current power base^(2^k). A non-positive exp performs no iterations and
// yields T(1).
template <typename T, typename I> __device__ __forceinline__ T ipow(T base, I exp) {
    T acc = 1;
    for (; exp > 0; exp >>= 1) {
        if ((exp & 1) != 0)
            acc *= base;
        base *= base;
    }
    return acc;
}

// Single-precision power function base**exponent.
//
// Cheap special cases are resolved first. Integral exponents whose magnitude fits
// in an unsigned int are evaluated by repeated squaring (ipow). Everything else
// (fractional, huge, infinite or NaN exponents) is delegated to powf, which
// implements the IEEE/C99 pow special-value rules. Those rules include NaN for a
// negative finite base with a non-integer exponent.
__device__ __forceinline__ float m_pow_float(float base, float exponent) {
    if (exponent == 0.0f)
        return 1.0f; // x**0 == 1 for every x, including NaN
    if (base == 0.0f)
        return powf(base, exponent); // keeps the sign of -0.0 for odd exponents; NaN exponent -> NaN
    if (base == 1.0f)
        return 1.0f; // 1**y == 1 for every y, including NaN
    if (exponent == 1.0f)
        return base;

    // Integral exponent whose magnitude converts to unsigned int without overflow.
    // 2^32 is exactly representable in float. NaN and +/-inf fail this test, so
    // they never reach the (otherwise undefined) float -> unsigned conversion.
    if (floorf(exponent) == exponent && fabsf(exponent) < 4294967296.0f) {
        const float p = ipow(base, static_cast<unsigned int>(fabsf(exponent)));
        if (exponent > 0.0f)
            return p;
        // Negative exponent: take the reciprocal, unless base**|exponent| overflowed.
        // In that case 1/p would collapse to 0, while powf can still return the
        // (possibly subnormal) true result.
        if (!isinf(p))
            return 1.0f / p;
    }

    // Fractional, very large, infinite or NaN exponents, plus overflowed reciprocals.
    return powf(base, exponent);
}

// Trigonometric Functions

// Tangent with its poles clamped to a finite magnitude.
//
// The argument is first folded into [-pi, pi] (fmodf by 2*pi, then at most one
// shift), so the pole test only has to look at +/-pi/2. Within 1e-6 of a pole the
// function returns +/-1e6 instead of the very large quotient. The sign is the one
// tan has on that side of the pole:
//   - just below pi/2 or just above -pi/2 (inside (-pi/2, pi/2)) the sign follows x;
//   - just beyond either pole the sign is the opposite of x.
// Outside the pole bands the result is tanf of the folded argument.
//
// All constants are single precision, so no double-precision arithmetic is
// emitted (M_PI itself is a double literal).
__device__ __forceinline__ float m_tan(float x) {
    const float kPi = static_cast<float>(M_PI);
    const float kTwoPi = 2.0f * kPi;
    const float kHalfPi = 0.5f * kPi;
    const float kPoleEps = 1e-6f;
    const float kPoleClamp = 1e6f;

    x = fmodf(x, kTwoPi);
    if (x > kPi)
        x -= kTwoPi;
    else if (x < -kPi)
        x += kTwoPi;

    if (fabsf(fabsf(x) - kHalfPi) < kPoleEps) {
        // float(pi/2) lies just above the true pi/2, so "< kHalfPi" classifies the
        // float neighbours of the pole onto the correct side.
        const bool inside = fabsf(x) < kHalfPi;
        return ((x > 0.0f) == inside) ? kPoleClamp : -kPoleClamp;
    }

    return tanf(x);
}

// Computes 10**x in single precision via the CUDA exp10f math function.
__device__ __forceinline__ float m_exp10(float x) {
    const float result = exp10f(x);
    return result;
}

// Zero-checking Utilities

// True when val is not equal to zero. NaN compares unequal to zero, so it counts
// as nonzero.
template <typename T> __device__ __forceinline__ bool is_nonzero(T val) { return !(val == T(0)); }

// Reduced-precision floating types: test after widening to float.
template <> __device__ __forceinline__ bool is_nonzero<__nv_fp8_e4m3>(__nv_fp8_e4m3 val) {
    return !(to_float(val) == 0.0f);
}

template <> __device__ __forceinline__ bool is_nonzero<__nv_fp8_e5m2>(__nv_fp8_e5m2 val) {
    return !(to_float(val) == 0.0f);
}

template <> __device__ __forceinline__ bool is_nonzero<__nv_bfloat16>(__nv_bfloat16 val) {
    return !(to_float(val) == 0.0f);
}

template <> __device__ __forceinline__ bool is_nonzero<__half>(__half val) {
    return !(to_float(val) == 0.0f);
}

// Sign Function

// Sign function: 1 for positive, -1 for negative, 0 otherwise. "Otherwise" covers
// zero, and for floating types also NaN, since both comparisons are false.
template <typename T> __device__ __forceinline__ T sign(T x) {
    const T zero = T(0);
    if (x > zero)
        return T(1);
    if (x < zero)
        return T(-1);
    return zero;
}

// Reduced-precision floating types: classify in float, then convert the
// resulting -1 / 0 / 1 back to the element type.
template <> __device__ __forceinline__ __nv_fp8_e4m3 sign<__nv_fp8_e4m3>(__nv_fp8_e4m3 x) {
    const float v = to_float(x);
    float s = 0.0f;
    if (v > 0.0f)
        s = 1.0f;
    else if (v < 0.0f)
        s = -1.0f;
    return from_float<__nv_fp8_e4m3>(s);
}

template <> __device__ __forceinline__ __nv_fp8_e5m2 sign<__nv_fp8_e5m2>(__nv_fp8_e5m2 x) {
    const float v = to_float(x);
    float s = 0.0f;
    if (v > 0.0f)
        s = 1.0f;
    else if (v < 0.0f)
        s = -1.0f;
    return from_float<__nv_fp8_e5m2>(s);
}

template <> __device__ __forceinline__ __nv_bfloat16 sign<__nv_bfloat16>(__nv_bfloat16 x) {
    const float v = to_float(x);
    float s = 0.0f;
    if (v > 0.0f)
        s = 1.0f;
    else if (v < 0.0f)
        s = -1.0f;
    return from_float<__nv_bfloat16>(s);
}

template <> __device__ __forceinline__ __half sign<__half>(__half x) {
    const float v = to_float(x);
    float s = 0.0f;
    if (v > 0.0f)
        s = 1.0f;
    else if (v < 0.0f)
        s = -1.0f;
    return from_float<__half>(s);
}

// Activation Functions

// ReLU: maximum(x, 0). Follows maximum's tie/NaN rule, returning the zero
// argument when x > 0 is false.
template <typename T> __device__ __forceinline__ T relu(T x) { return maximum(x, static_cast<T>(0)); }

// bfloat16 and half are given their zero from a float literal rather than from an
// int (the primary template's static_cast<T>(0)).
template <> __device__ __forceinline__ __nv_bfloat16 relu<__nv_bfloat16>(__nv_bfloat16 x) {
    return maximum(x, static_cast<__nv_bfloat16>(0.0f));
}

template <> __device__ __forceinline__ __half relu<__half>(__half x) {
    return maximum(x, static_cast<__half>(0.0f));
}

// Logistic sigmoid 1 / (1 + exp(-x)), evaluated in float and converted back to T.
template <typename T> __device__ __forceinline__ T sigmoid(T x) {
    const float neg_exp = expf(-to_float(x));
    return from_float<T>(1.0f / (1.0f + neg_exp));
}

// GELU activation, tanh approximation, evaluated in float:
//   0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3))),  c = 0.7978845608... ~= sqrt(2/pi)
template <typename T> __device__ __forceinline__ T gelu(T x) {
    const float fx = to_float(x);
    const float cubic_term = 0.044715f * fx * fx * fx;
    const float tanh_arg = 0.7978845608028654f * (fx + cubic_term);
    const float gate = 1.0f + tanhf(tanh_arg);
    return from_float<T>(0.5f * fx * gate);
}

// Softplus activation log(1 + exp(x)), evaluated in float and converted back to T.
//
// Uses the equivalent, overflow-safe form  max(x, 0) + log1p(exp(-|x|)):
//  * the argument to expf is never positive, so large inputs no longer overflow
//    expf to inf (the naive form returned inf for x above ~88.7 instead of ~x);
//  * log1pf keeps the small result for very negative inputs, where
//    1.0f + expf(x) would round to exactly 1.0f and log would return 0.
// A NaN input still yields NaN through the log1pf term.
template <typename T> __device__ __forceinline__ T softplus(T x) {
    const float fx = to_float(x);
    const float result = fmaxf(fx, 0.0f) + log1pf(expf(-fabsf(fx)));
    return from_float<T>(result);
}

// SiLU (swish) activation x / (1 + exp(-x)), i.e. x * sigmoid(x), evaluated in float.
template <typename T> __device__ __forceinline__ T silu(T x) {
    const float fx = to_float(x);
    const float denom = 1.0f + expf(-fx);
    return from_float<T>(fx / denom);
}

// Mish activation x * tanh(softplus(x)), evaluated in float and converted back to T.
//
// The softplus term uses the overflow-safe form  max(x, 0) + log1p(exp(-|x|)).
// For very negative x the naive logf(1.0f + expf(x)) rounds to 0, so mish
// returned -0 instead of its small negative value. log1pf keeps that value, and
// expf never sees a positive argument.
template <typename T> __device__ __forceinline__ T mish(T x) {
    const float fx = to_float(x);
    const float sp = fmaxf(fx, 0.0f) + log1pf(expf(-fabsf(fx)));
    return from_float<T>(fx * tanhf(sp));
}