// hanzo-rocm-kernels 0.11.3
//
// ROCm/HIP kernels for Hanzo.
#ifndef __HIPCC__
#define __device__
#define __global__
#define __forceinline__
#else
#include <hip/hip_runtime.h>
#endif

#include <stddef.h>
#include <stdint.h>

#if defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#endif

// Native positions-aware rotary embedding. Matches engine `apply_rotary_cpu_inner`
// (hanzo-quant/src/rotary/mod.rs) and hanzo_nn::rotary_emb::rope / rope_i.
//
// Layouts (all contiguous, row-major):
//   q          [b, h,  t, d]
//   k          [b, kh, t, d]    optional; pass k == nullptr to rotate q only
//   cos, sin   [max_pos, d/2]   full cache tables
//   positions  u32[b]           per-sequence start offset
// The cache row for token (batch_idx, seq_idx) is positions[batch_idx] + seq_idx, so no
// host-side index tensor has to be built.
//
// Work decomposition: there are b * max(h, kh) * t * (d/2) rotation pairs, and each pair
// is handled by exactly one thread. The loop is grid-stride: a launch of
// ceil(total / blockDim.x) blocks yields one pair per thread, and a smaller grid still
// covers every pair. neox pairs are (i, i + d/2); gpt-j (interleaved) pairs are
// (2*i, 2*i + 1). All math is done in float so f16/bf16 stay accurate.
//
// Index arithmetic is 64-bit because q/k element offsets exceed 2^32 for large prefill
// batches (e.g. b=32, h=64, t=32768, d=128), where 32-bit offsets would wrap.
//
// Preconditions (not checked here):
//   - d is even (for odd d the last element of every row is never written);
//   - positions[i] + t <= max_pos for every batch entry i;
//   - k_out is non-null whenever k is non-null.
// q_out/k_out may alias q/k (in-place): every element is read and written by exactly one
// thread, which reads its pair's inputs before writing its outputs.

// Rotates one (x, y) pair of a single row: reads src[base + x_off] and src[base + y_off]
// and writes (x*c - y*s, y*c + x*s) to the same offsets of dst.
template <typename T>
__device__ __forceinline__ void rope_rotate_pair(
    const T * src, T * dst, const uint64_t base,
    const unsigned int x_off, const unsigned int y_off,
    const float c, const float s) {
    const float x = (float)src[base + x_off];
    const float y = (float)src[base + y_off];
    dst[base + x_off] = (T)(x * c - y * s);
    dst[base + y_off] = (T)(y * c + x * s);
}

template <typename T>
__device__ __forceinline__ void rope_positions(
    const T * q, const T * k, const T * cos, const T * sin,
    const uint32_t * positions, T * q_out, T * k_out,
    const unsigned int b, const unsigned int h, const unsigned int kh,
    const unsigned int t, const unsigned int d, const unsigned int is_neox) {
    const unsigned int half = d / 2;
    // Decompose over the larger head count; slots with i_head >= h only touch k and
    // slots with i_head >= kh only touch q.
    const unsigned int max_heads = h > kh ? h : kh;
    // Zero whenever half, t, b or max_heads is zero, so the loop below never divides by 0.
    const uint64_t total = (uint64_t)b * max_heads * t * half;
    const uint64_t stride = (uint64_t)gridDim.x * blockDim.x;

    for (uint64_t idx = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; idx < total;
         idx += stride) {
        // idx == ((i_b * max_heads + i_head) * t + i_t) * half + pair
        const unsigned int pair = (unsigned int)(idx % half);
        uint64_t rest = idx / half;
        const unsigned int i_t = (unsigned int)(rest % t);
        rest /= t;
        const unsigned int i_head = (unsigned int)(rest % max_heads);
        const unsigned int i_b = (unsigned int)(rest / max_heads);

        // Row of the cos/sin cache for this token, offset by the sequence's start position.
        const uint64_t i_cs = ((uint64_t)positions[i_b] + i_t) * half + pair;
        const float c = (float)cos[i_cs];
        const float s = (float)sin[i_cs];

        const unsigned int x_off = is_neox ? pair : pair * 2;
        const unsigned int y_off = is_neox ? pair + half : pair * 2 + 1;

        if (i_head < h) {
            const uint64_t base = (((uint64_t)i_b * h + i_head) * t + i_t) * d;
            rope_rotate_pair<T>(q, q_out, base, x_off, y_off, c, s);
        }
        if (k != nullptr && i_head < kh) {
            const uint64_t base = (((uint64_t)i_b * kh + i_head) * t + i_t) * d;
            rope_rotate_pair<T>(k, k_out, base, x_off, y_off, c, s);
        }
    }
}

#define ROPE_POS_OP(DTYPE, KERNEL_NAME)                                              \
    extern "C" __global__ void KERNEL_NAME(                                          \
        const DTYPE *q, const DTYPE *k, const DTYPE *cos, const DTYPE *sin,          \
        const uint32_t *positions, DTYPE *q_out, DTYPE *k_out,                       \
        const unsigned int b, const unsigned int h, const unsigned int kh,           \
        const unsigned int t, const unsigned int d, const unsigned int is_neox) {    \
        rope_positions<DTYPE>(q, k, cos, sin, positions, q_out, k_out,               \
                              b, h, kh, t, d, is_neox);                              \
    }
// ROPE_POS_OP(DTYPE, KERNEL_NAME) emits one kernel with C linkage (unmangled symbol
// KERNEL_NAME; presumably so the host can look it up by name) that forwards every
// argument unchanged to rope_positions<DTYPE>.

// The f32 kernel is always emitted.
ROPE_POS_OP(float, rope_positions_f32)

// Half-precision kernels need __half / hip_bfloat16, whose headers are only included
// when compiling with hipcc.
#ifdef __HIPCC__
ROPE_POS_OP(__half, rope_positions_f16)
ROPE_POS_OP(hip_bfloat16, rope_positions_bf16)
#endif