oxibonsai-kernels 0.1.1

1-bit Q1_0_g128 compute kernels (dequant, GEMV, GEMM) for OxiBonsai
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
//! CUDA C kernel source strings for OxiBonsai attention operations.
//!
//! # Attention kernel inventory
//!
//! | Kernel                      | Grid                                     | Block      | Notes                                           |
//! |-----------------------------|------------------------------------------|------------|-------------------------------------------------|
//! | `fused_qk_norm`             | `(nq+nkv, 1, 1)`                         | `(256,1,1)`| RMSNorm both Q and K heads in one dispatch      |
//! | `fused_qk_rope`             | `(ceil(half_dim/64), nq+nkv, 1)`         | `(64,1,1)` | Rotary position embedding for Q and K           |
//! | `fused_qk_norm_rope`        | `(nq+nkv, 1, 1)`                         | `(256,1,1)`| Combined RMSNorm + RoPE for Q and K             |
//! | `fused_kv_store`            | `(ceil(head_dim/64), nkv, 1)`            | `(64,1,1)` | Store K (after RoPE) and V into FP16 KV cache   |
//! | `batched_attn_scores_v2`    | `(n_q, ceil(seq_len/batch_stride), 1)`   | `(128,1,1)`| Batched dot-product attention scores            |
//! | `batched_softmax`           | `(n_q, 1, 1)`                            | `(256,1,1)`| Per-head numerically-stable softmax             |
//! | `batched_attn_weighted_sum` | `(ceil(head_dim/64), n_q, 1)`            | `(64,1,1)` | Weighted sum of V vectors by softmax weights    |
//!
//! # FP16 KV cache
//!
//! K and V are stored in FP16 (u16 on the Rust side) to halve VRAM usage.
//! `fused_kv_store` writes FP16; the score and weighted-sum kernels read FP16.
//!
//! # Weight layout
//!
//! Q/K/V/O/gate+up/down weight blocks use the Q1_0_G128 SoA layout identical
//! to the FFN kernels in `cuda_kernels.rs`.

#![cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]

/// Combined CUDA C source for all 7 attention kernels used in the decode path.
///
/// Compiled once at process startup (lazily on first `encode_attn_phase` call)
/// via `cudarc::nvrtc::compile_ptx`.  All kernels share a single PTX module.
pub const CUDA_ATTENTION_KERNELS_SRC: &str = r#"
/* =========================================================================
   OxiBonsai CUDA attention kernels — decode path
   All math intrinsics (sqrtf, expf, fmaxf, rsqrtf) are NVRTC built-ins;
   no CUDA SDK headers are required or included.
   ========================================================================= */

/* ── FP16 ↔ FP32 helpers ─────────────────────────────────────────────────── */
static __device__ __forceinline__ float fast_fp16_to_float(unsigned short h) {
    float f;
    asm("cvt.f32.f16 %0, %1;" : "=f"(f) : "h"(h));
    return f;
}

static __device__ __forceinline__ unsigned short float_to_fp16(float f) {
    unsigned short h;
    asm("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
    return h;
}

/* =========================================================================
   Kernel 1 — fused_qk_norm
   Fused RMSNorm for Q and K heads.

   Grid:  (nq + nkv, 1, 1)
   Block: (256, 1, 1)
   First nq CTAs normalise Q heads; remaining nkv CTAs normalise K heads.
   Each CTA uses shared-memory parallel reduction for sum-of-squares.
   ========================================================================= */
extern "C" __global__ void fused_qk_norm(
    const float* __restrict__ q_in,
    const float* __restrict__ k_in,
    float*       __restrict__ q_out,
    float*       __restrict__ k_out,
    const float* __restrict__ q_weight,
    const float* __restrict__ k_weight,
    unsigned int nq,
    unsigned int nkv,
    unsigned int head_dim,
    float        eps
) {
    const unsigned int gid = blockIdx.x;
    const unsigned int tid = threadIdx.x;
    const unsigned int tpg = blockDim.x;

    const bool is_q = (gid < nq);
    const unsigned int head_idx = is_q ? gid : (gid - nq);

    const float* __restrict__ in_ptr  = is_q ? (q_in + (unsigned long long)head_idx * head_dim)
                                              : (k_in + (unsigned long long)head_idx * head_dim);
    float*       __restrict__ out_ptr = is_q ? (q_out + (unsigned long long)head_idx * head_dim)
                                              : (k_out + (unsigned long long)head_idx * head_dim);
    const float* __restrict__ w_ptr   = is_q ? q_weight : k_weight;

    __shared__ float shared_sum[256];
    float local_sq = 0.0f;
    for (unsigned int i = tid; i < head_dim; i += tpg) {
        float v = in_ptr[i];
        local_sq += v * v;
    }
    shared_sum[tid] = local_sq;
    __syncthreads();

    for (unsigned int stride = tpg >> 1u; stride > 0u; stride >>= 1u) {
        if (tid < stride) shared_sum[tid] += shared_sum[tid + stride];
        __syncthreads();
    }

    float rms_inv = rsqrtf(shared_sum[0] / (float)head_dim + eps);

    for (unsigned int i = tid; i < head_dim; i += tpg) {
        out_ptr[i] = in_ptr[i] * rms_inv * w_ptr[i];
    }
}

/* =========================================================================
   Kernel 2 — fused_qk_rope
   Apply rotary position embedding to Q and K.

   Grid:  (ceil(half_dim/64), nq + nkv, 1)
   Block: (64, 1, 1)
   blockIdx.y < nq  →  apply to Q head blockIdx.y
   blockIdx.y >= nq →  apply to K head (blockIdx.y - nq)
   ========================================================================= */
extern "C" __global__ void fused_qk_rope(
    const float* __restrict__ q_in,
    const float* __restrict__ k_in,
    float*       __restrict__ q_out,
    float*       __restrict__ k_out,
    const float* __restrict__ cos_buf,
    const float* __restrict__ sin_buf,
    unsigned int nq,
    unsigned int nkv,
    unsigned int half_dim
) {
    const unsigned int d       = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int head_y  = blockIdx.y;
    if (d >= half_dim) return;

    const bool is_q = (head_y < nq);
    const unsigned int head_idx = is_q ? head_y : (head_y - nq);
    const unsigned int head_dim = half_dim * 2u;

    const float* __restrict__ in_ptr  = is_q ? (q_in + (unsigned long long)head_idx * head_dim)
                                              : (k_in + (unsigned long long)head_idx * head_dim);
    float*       __restrict__ out_ptr = is_q ? (q_out + (unsigned long long)head_idx * head_dim)
                                              : (k_out + (unsigned long long)head_idx * head_dim);

    float c  = cos_buf[d];
    float s  = sin_buf[d];
    float x0 = in_ptr[d];
    float x1 = in_ptr[d + half_dim];
    out_ptr[d]          = x0 * c - x1 * s;
    out_ptr[d + half_dim] = x0 * s + x1 * c;
}

/* =========================================================================
   Kernel 3 — fused_qk_norm_rope
   Combined RMSNorm + RoPE for Q and K in a single dispatch.

   Grid:  (nq + nkv, 1, 1)
   Block: (256, 1, 1)
   Each CTA processes one Q or K head:
     Phase 1: parallel sum-of-squares → rms_inv
     Phase 2: normalise each element pair and apply RoPE in-place
   ========================================================================= */
extern "C" __global__ void fused_qk_norm_rope(
    const float* __restrict__ q_in,
    const float* __restrict__ k_in,
    float*       __restrict__ q_out,
    float*       __restrict__ k_out,
    const float* __restrict__ q_weight,
    const float* __restrict__ k_weight,
    const float* __restrict__ cos_buf,
    const float* __restrict__ sin_buf,
    unsigned int nq,
    unsigned int nkv,
    unsigned int head_dim,
    float        eps
) {
    const unsigned int gid = blockIdx.x;
    const unsigned int tid = threadIdx.x;
    const unsigned int tpg = blockDim.x;

    const bool is_q = (gid < nq);
    const unsigned int head_idx = is_q ? gid : (gid - nq);
    const unsigned int half_dim = head_dim >> 1u;

    const float* __restrict__ in_ptr  = is_q ? (q_in + (unsigned long long)head_idx * head_dim)
                                              : (k_in + (unsigned long long)head_idx * head_dim);
    float*       __restrict__ out_ptr = is_q ? (q_out + (unsigned long long)head_idx * head_dim)
                                              : (k_out + (unsigned long long)head_idx * head_dim);
    const float* __restrict__ w_ptr   = is_q ? q_weight : k_weight;

    /* Phase 1: sum of squares */
    __shared__ float shared_sum[256];
    float local_sq = 0.0f;
    for (unsigned int i = tid; i < head_dim; i += tpg) {
        float v = in_ptr[i];
        local_sq += v * v;
    }
    shared_sum[tid] = local_sq;
    __syncthreads();

    for (unsigned int stride = tpg >> 1u; stride > 0u; stride >>= 1u) {
        if (tid < stride) shared_sum[tid] += shared_sum[tid + stride];
        __syncthreads();
    }

    float rms_inv = rsqrtf(shared_sum[0] / (float)head_dim + eps);

    /* Phase 2: normalise + RoPE in one pass over [0, half_dim) */
    for (unsigned int d = tid; d < half_dim; d += tpg) {
        float normed_lo = in_ptr[d]          * rms_inv * w_ptr[d];
        float normed_hi = in_ptr[d + half_dim] * rms_inv * w_ptr[d + half_dim];

        float c = cos_buf[d];
        float s = sin_buf[d];
        out_ptr[d]          = normed_lo * c - normed_hi * s;
        out_ptr[d + half_dim] = normed_lo * s + normed_hi * c;
    }
}

/* =========================================================================
   Kernel 4 — fused_kv_store
   Store K (post-RoPE) and V into the FP16 KV cache in a single dispatch.

   Grid:  (ceil(head_dim/64), nkv, 1)
   Block: (64, 1, 1)
   Each thread stores one element of both K and V for one head.

   k_cache / v_cache layout: [n_layers × nkv × max_seq × head_dim]  (FP16)
   dst_offset = layer_offset + (head * max_seq + pos) * head_dim + d
   ========================================================================= */
extern "C" __global__ void fused_kv_store(
    const float*          __restrict__ k_data,
    const float*          __restrict__ v_data,
    unsigned short*       __restrict__ k_cache,
    unsigned short*       __restrict__ v_cache,
    unsigned int head_dim,
    unsigned int nkv,
    unsigned int max_seq,
    const unsigned int* __restrict__ d_pos_seqlen,
    unsigned int layer_offset
) {
    const unsigned int pos  = d_pos_seqlen[0];
    const unsigned int d    = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int head = blockIdx.y;
    if (d >= head_dim || head >= nkv) return;

    const unsigned long long src_off = (unsigned long long)head * head_dim + d;
    const unsigned long long dst_off = (unsigned long long)layer_offset
                                     + ((unsigned long long)head * max_seq + pos) * head_dim
                                     + d;

    k_cache[dst_off] = float_to_fp16(k_data[src_off]);
    v_cache[dst_off] = float_to_fp16(v_data[src_off]);
}

/* =========================================================================
   Kernel 5 — batched_attn_scores_v2
   Batched dot-product attention scores with Q-vector shared-memory caching.

   Grid:  (n_q, ceil(seq_len / batch_stride), 1)
   Block: (128, 1, 1)

   Each CTA handles one Q-head and processes `batch_stride` positions.
   The Q vector is loaded into shared memory once and reused for each position.
   128 threads = head_dim, so every thread is active (no idle threads).
   ========================================================================= */
extern "C" __global__ void batched_attn_scores_v2(
    const float*          __restrict__ queries,
    const unsigned short* __restrict__ k_cache,
    float*                __restrict__ all_scores,
    unsigned int head_dim,
    unsigned int n_q,
    unsigned int n_kv,
    unsigned int heads_per_group,
    unsigned int max_seq,
    const unsigned int* __restrict__ d_pos_seqlen,
    float        inv_sqrt_hd,
    unsigned int cache_layer_offset,
    unsigned int batch_stride
) {
    const unsigned int seq_len  = d_pos_seqlen[1];
    const unsigned int q_head   = blockIdx.x;
    const unsigned int batch_id = blockIdx.y;
    const unsigned int tid      = threadIdx.x;
    if (q_head >= n_q) return;

    const unsigned int kv_head  = q_head / heads_per_group;
    const unsigned int pos_start = batch_id * batch_stride;

    /* Load Q vector into shared memory — reused across all positions */
    __shared__ float shared_q[128];
    if (tid < head_dim) {
        shared_q[tid] = queries[(unsigned long long)q_head * head_dim + tid];
    }
    __syncthreads();

    for (unsigned int pos_t = pos_start;
         pos_t < pos_start + batch_stride && pos_t < seq_len;
         pos_t++)
    {
        const unsigned short* __restrict__ key =
            k_cache + (unsigned long long)cache_layer_offset
            + ((unsigned long long)kv_head * max_seq + pos_t) * head_dim;

        float my_prod = 0.0f;
        if (tid < head_dim) {
            my_prod = shared_q[tid] * fast_fp16_to_float(key[tid]);
        }

        /* Warp-level reduction using __shfl_down_sync */
        unsigned int mask = 0xffffffffu;
        for (int offset = 16; offset > 0; offset >>= 1) {
            my_prod += __shfl_down_sync(mask, my_prod, offset);
        }

        /* Cross-warp reduction: 128 threads = 4 warps */
        __shared__ float warp_sums[4];
        const unsigned int warp_id = tid >> 5u;
        const unsigned int lane    = tid & 31u;
        if (lane == 0u) {
            warp_sums[warp_id] = my_prod;
        }
        __syncthreads();

        if (tid == 0u) {
            float total = warp_sums[0] + warp_sums[1] + warp_sums[2] + warp_sums[3];
            all_scores[(unsigned long long)q_head * max_seq + pos_t] = total * inv_sqrt_hd;
        }
        __syncthreads();
    }
}

/* =========================================================================
   Kernel 6 — batched_softmax
   Per-head numerically-stable softmax (in-place).

   Grid:  (n_q, 1, 1)
   Block: (256, 1, 1)

   Three-pass approach:
     Pass 1: find max (parallel reduction)
     Pass 2: compute exp(x − max) and sum
     Pass 3: normalise by sum
   ========================================================================= */
extern "C" __global__ void batched_softmax(
    float*       __restrict__ all_scores,
    unsigned int n_q,
    unsigned int max_seq,
    const unsigned int* __restrict__ d_pos_seqlen
) {
    const unsigned int seq_len = d_pos_seqlen[1];
    const unsigned int tgpig  = blockIdx.x;
    const unsigned int tid    = threadIdx.x;
    const unsigned int tg_size = blockDim.x;
    if (tgpig >= n_q) return;

    float* __restrict__ scores = all_scores + (unsigned long long)tgpig * max_seq;

    __shared__ float shared[256];

    /* Pass 1: max */
    float local_max = -3.402823e+38f;
    for (unsigned int i = tid; i < seq_len; i += tg_size) {
        float v = scores[i];
        if (v > local_max) local_max = v;
    }
    shared[tid] = local_max;
    __syncthreads();
    for (unsigned int s = tg_size >> 1u; s > 0u; s >>= 1u) {
        if (tid < s) {
            if (shared[tid + s] > shared[tid]) shared[tid] = shared[tid + s];
        }
        __syncthreads();
    }
    float gmax = shared[0];
    __syncthreads();

    /* Pass 2: exp + sum */
    float local_sum = 0.0f;
    for (unsigned int i = tid; i < seq_len; i += tg_size) {
        float e = expf(scores[i] - gmax);
        scores[i] = e;
        local_sum += e;
    }
    shared[tid] = local_sum;
    __syncthreads();
    for (unsigned int s = tg_size >> 1u; s > 0u; s >>= 1u) {
        if (tid < s) shared[tid] += shared[tid + s];
        __syncthreads();
    }
    float gsum = shared[0];
    __syncthreads();

    /* Pass 3: normalise */
    float inv_sum = (gsum > 0.0f) ? (1.0f / gsum) : 0.0f;
    for (unsigned int i = tid; i < seq_len; i += tg_size) {
        scores[i] *= inv_sum;
    }
}

/* =========================================================================
   Kernel 7 — batched_attn_weighted_sum
   Weighted sum of V vectors: output[d] = Σ_t scores[t] × V[t][d].

   Grid:  (ceil(head_dim/64), n_q, 1)
   Block: (64, 1, 1)

   Each thread accumulates one output dimension d for one Q head.
   GQA mapping: kv_head = q_head / heads_per_group.
   V is stored in FP16 in the KV cache.
   ========================================================================= */
extern "C" __global__ void batched_attn_weighted_sum(
    const float*          __restrict__ all_scores,
    const unsigned short* __restrict__ v_cache,
    float*                __restrict__ attn_out,
    unsigned int head_dim,
    unsigned int n_q,
    unsigned int n_kv,
    unsigned int heads_per_group,
    unsigned int max_seq,
    const unsigned int* __restrict__ d_pos_seqlen,
    unsigned int cache_layer_offset
) {
    const unsigned int seq_len = d_pos_seqlen[1];
    const unsigned int d      = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int q_head = blockIdx.y;
    if (d >= head_dim || q_head >= n_q) return;

    const unsigned int kv_head = q_head / heads_per_group;
    const float* __restrict__ scores =
        all_scores + (unsigned long long)q_head * max_seq;
    const unsigned short* __restrict__ values =
        v_cache + (unsigned long long)cache_layer_offset
        + (unsigned long long)kv_head * max_seq * head_dim;

    float acc = 0.0f;
    for (unsigned int t = 0u; t < seq_len; t++) {
        acc += scores[t] * fast_fp16_to_float(values[(unsigned long long)t * head_dim + d]);
    }
    attn_out[(unsigned long long)q_head * head_dim + d] = acc;
}
"#;