ferrum-kernels 0.7.7

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
#include <metal_stdlib>
#include <metal_simdgroup>
using namespace metal;

// ── RMS Norm ────────────────────────────────────────────────────────────
// out[i] = (x[i] / rms) * weight[i], where rms = sqrt(mean(x^2) + eps)
// One threadgroup per row. Uses simd_sum for parallel reduction.

// Parameter block read by rms_norm_f32 through buffer(3).
// NOTE(review): presumably mirrored by a host-side struct with the same
// field order — confirm before changing the layout.
struct RmsNormParams {
    int dim;      // hidden dimension: number of floats per row
    float eps;    // added to mean(x^2) before the square root
};

// Row-wise RMSNorm:
//   output[row, i] = input[row, i] * weight[i] / sqrt(mean(input[row, :]^2) + eps)
//
// Launch layout: one threadgroup per row (row = threadgroup_position_in_grid).
// Any threadgroup width works: the row is split across ALL threads of the
// threadgroup (stride = ntg), each SIMD-group reduces its own partial sum
// with simd_sum, and when ntg > 32 the per-SIMD-group partials are combined
// through `shmem`.
//
// Requirements:
//   * `shmem` must hold at least 32 floats whenever ntg > 32 (it is not
//     touched otherwise).
//   * The code assumes a SIMD-group width of 32 (thread index inside the
//     threadgroup is reconstructed as sgitg * 32 + tiisg).
//
// The previous version strode every SIMD-group over the whole row
// (index = tiisg, stride = 32). With more than one SIMD-group each of them
// already held the full sum of squares, so the cross-SIMD-group reduction
// multiplied it by the number of SIMD-groups; with fewer than 32 threads
// part of the row was never visited. For ntg == 32 the result is unchanged.
kernel void rms_norm_f32(
    device const float* input   [[buffer(0)]],
    device const float* weight  [[buffer(1)]],
    device       float* output  [[buffer(2)]],
    constant RmsNormParams& p   [[buffer(3)]],
    uint  tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    uint  sgitg [[simdgroup_index_in_threadgroup]],
    uint  ntg   [[threads_per_threadgroup]],
    threadgroup float* shmem [[threadgroup(0)]])
{
    const int row = tgpig;
    device const float* x = input + row * p.dim;
    device float* y = output + row * p.dim;

    // Position of this thread inside the threadgroup and the stride that
    // makes the threadgroup as a whole cover every element exactly once.
    const int first  = int(sgitg) * 32 + int(tiisg);
    const int stride = int(ntg);

    // Per-thread partial sum of squares, then reduce inside the SIMD-group.
    float sum_sq = 0.0f;
    for (int i = first; i < p.dim; i += stride) {
        float v = x[i];
        sum_sq += v * v;
    }
    sum_sq = simd_sum(sum_sq);

    // Cross-simdgroup reduction if ntg > 32. `ntg` is uniform across the
    // threadgroup, so every thread reaches both barriers.
    if (ntg > 32) {
        // Zero all 32 slots first so slots of non-existent SIMD-groups
        // contribute nothing to the final simd_sum.
        if (sgitg == 0) shmem[tiisg] = 0.0f;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) shmem[sgitg] = sum_sq;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum_sq = shmem[tiisg];
        sum_sq = simd_sum(sum_sq);
    }

    const float scale = 1.0f / sqrt(sum_sq / float(p.dim) + p.eps);

    // Apply norm and weight; same partitioning as the reduction pass.
    for (int i = first; i < p.dim; i += stride) {
        y[i] = x[i] * scale * weight[i];
    }
}

// ── SiLU × Gate (SwiGLU) ───────────────────────────────────────────────
// out[i] = silu(gate[i]) * up[i]
// Simple element-wise, one thread per element.

// Parameter block read by silu_mul_f32 through buffer(3).
struct SiluMulParams {
    int n; // total elements; threads with index >= n exit without writing
};

// Element-wise SwiGLU activation: output[i] = silu(gate[i]) * up[i],
// with silu(x) = x / (1 + exp(-x)).
// One thread per element; threads at or beyond p.n do nothing.
kernel void silu_mul_f32(
    device const float* gate    [[buffer(0)]],
    device const float* up      [[buffer(1)]],
    device       float* output  [[buffer(2)]],
    constant SiluMulParams& p   [[buffer(3)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid < uint(p.n)) {
        const float x = gate[tid];
        const float activated = x / (1.0f + exp(-x));
        output[tid] = activated * up[tid];
    }
}

// ── Residual Add ────────────────────────────────────────────────────────
// out[i] = a[i] + b[i]

// Parameter block read by add_f32 through buffer(3).
struct AddParams {
    int n;   // element count; threads with index >= n exit without writing
};

// Element-wise sum: output[i] = a[i] + b[i] for i in [0, p.n).
// One thread per element; surplus threads are masked off.
kernel void add_f32(
    device const float* a       [[buffer(0)]],
    device const float* b       [[buffer(1)]],
    device       float* output  [[buffer(2)]],
    constant AddParams& p       [[buffer(3)]],
    uint tid [[thread_position_in_grid]])
{
    const bool in_range = tid < uint(p.n);
    if (in_range) {
        output[tid] = a[tid] + b[tid];
    }
}

// ── Scalar-scaled add: dst[i] += scale * src[i] ─────────────────────────
// MoE expert combine — used in the per-(token, expert) accumulate loop
// where each contribution is weighted by a router-derived scalar.

// Parameter block read by scaled_add_inplace_f32 through buffer(2).
struct ScaledAddParams {
    int   n;       // element count; threads with index >= n exit without writing
    float scale;   // scalar multiplied into every src element before accumulation
};

// In-place scaled accumulate: dst[i] = dst[i] + p.scale * src[i] for i in [0, p.n).
// One thread per element; surplus threads are masked off.
kernel void scaled_add_inplace_f32(
    device       float* dst       [[buffer(0)]],
    device const float* src       [[buffer(1)]],
    constant ScaledAddParams& p   [[buffer(2)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid < uint(p.n)) {
        const float contribution = p.scale * src[tid];
        dst[tid] = dst[tid] + contribution;
    }
}

// ── Element-wise Multiply (broadcast scale) ─────────────────────────────
// out[i] = a[i] * scale[i % scale_len]

// Parameter block shared by mul_scale_f32 (buffer(3)) and
// fused_scale_add_f32 (buffer(4)).
struct MulScaleParams {
    int n;           // element count; threads with index >= n exit without writing
    int scale_len;   // length of the scale vector; it is indexed with i % scale_len
};

// Broadcast multiply: output[i] = a[i] * scale[i % p.scale_len] for i in [0, p.n).
// One thread per element; surplus threads are masked off.
kernel void mul_scale_f32(
    device const float* a       [[buffer(0)]],
    device const float* scale   [[buffer(1)]],
    device       float* output  [[buffer(2)]],
    constant MulScaleParams& p  [[buffer(3)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid < uint(p.n)) {
        // The scale vector repeats every scale_len elements.
        const float factor = scale[tid % p.scale_len];
        output[tid] = a[tid] * factor;
    }
}

// ── Fused Scale-Add: out = a + b * scale ────────────────────────────────
// Single kernel replaces mul_scale + add (saves 1 dispatch + 1 encoder barrier)

// Fused broadcast-scale and add:
//   output[i] = a[i] + b[i] * scale[i % p.scale_len]   for i in [0, p.n).
// One thread per element; surplus threads are masked off.
kernel void fused_scale_add_f32(
    device const float* a       [[buffer(0)]],   // residual
    device const float* b       [[buffer(1)]],   // attn/mlp output
    device const float* scale   [[buffer(2)]],   // layer_scale vector
    device       float* output  [[buffer(3)]],
    constant MulScaleParams& p  [[buffer(4)]],   // reuse params struct
    uint tid [[thread_position_in_grid]])
{
    if (tid < uint(p.n)) {
        const float scaled_branch = b[tid] * scale[tid % p.scale_len];
        output[tid] = a[tid] + scaled_branch;
    }
}

// ── Fused Residual-Add + RMSNorm ────────────────────────────────────────
// out_residual = a + b  (or a + b * scale[i % scale_len] when has_scale != 0)
// out_norm = rms_norm(out_residual) * weight
// Saves 2 dispatches + 1 encoder barrier

// Parameter block read by fused_residual_norm_f32 through buffer(6).
struct FusedResNormParams {
    int tokens;      // number of rows; threadgroups with index >= tokens return immediately
    int dim;         // floats per row
    float eps;       // added to mean(r^2) before the square root
    int has_scale;   // 0 = no scale, 1 = apply scale to b before add
    int scale_len;   // length of the scale vector; indexed with i % scale_len
};

// Fused residual add + RMSNorm for one row per threadgroup.
//   r[i]        = a[i] + b[i]                         (has_scale == 0)
//   r[i]        = a[i] + b[i] * scale[i % scale_len]  (has_scale != 0)
//   out_res[i]  = r[i]
//   out_norm[i] = r[i] * weight[i] / sqrt(mean(r^2) + eps)
//
// The row is walked with index = thread_index_in_simdgroup and stride 32,
// and the reduction is a single simd_sum; `shmem` is not used.
// NOTE(review): this looks written for a 32-thread threadgroup — with more
// SIMD-groups each one would repeat the same work. Confirm against the host
// dispatch.
kernel void fused_residual_norm_f32(
    device const float* a       [[buffer(0)]],   // residual input
    device const float* b       [[buffer(1)]],   // attn/mlp output
    device const float* scale   [[buffer(2)]],   // layer_scale (or dummy if has_scale=0)
    device const float* weight  [[buffer(3)]],   // norm weight
    device       float* out_res [[buffer(4)]],   // residual output (a + b*scale)
    device       float* out_norm[[buffer(5)]],   // normalized output
    constant FusedResNormParams& p [[buffer(6)]],
    uint  tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    threadgroup float* shmem [[threadgroup(0)]])
{
    const int token = tgpig;
    if (token >= p.tokens) return;

    // Row base pointers.
    device const float* lhs      = a + token * p.dim;
    device const float* rhs      = b + token * p.dim;
    device float*       residual = out_res + token * p.dim;
    device float*       normed   = out_norm + token * p.dim;

    // Pass 1: write the residual and accumulate its squares.
    float partial = 0.0f;
    for (int col = tiisg; col < p.dim; col += 32) {
        const float branch = p.has_scale ? rhs[col] * scale[col % p.scale_len]
                                         : rhs[col];
        const float r = lhs[col] + branch;
        residual[col] = r;
        partial += r * r;
    }

    // Reduce across the SIMD-group and derive the inverse RMS.
    const float total = simd_sum(partial);
    const float inv_rms = 1.0f / sqrt(total / float(p.dim) + p.eps);

    // Pass 2: each lane re-reads the residual values it wrote itself.
    for (int col = tiisg; col < p.dim; col += 32) {
        normed[col] = residual[col] * inv_rms * weight[col];
    }
}

// ── GEMM: C = A @ B^T ──────────────────────────────────────────────────
// A: [M, K], B: [N, K] (row-major), C: [M, N]
// Uses simdgroup_matrix for 8x8 tiles.
// One threadgroup computes one 8×8 tile of C.

// Problem dimensions for gemm_f32 and gemv_f32 (both read it through buffer(3)).
// gemv_f32 reads only N and K.
struct GemmParams {
    int M;   // rows of A and of C
    int N;   // rows of B, columns of C
    int K;   // shared inner dimension (columns of A and of B)
};

// C = A @ B^T, one 8x8 tile of C per threadgroup.
//   A: [M, K] row-major, B: [N, K] row-major, C: [M, N] row-major.
//
// Launch layout: tgpig.x selects the tile column (N direction), tgpig.y the
// tile row (M direction). Tiles that start outside C return immediately.
// NOTE(review): the tile loads stride by 32 over thread_index_in_simdgroup,
// so this looks written for a single 32-thread SIMD-group per threadgroup —
// confirm against the host dispatch.
//
// Scratch: `shmem_a` and `shmem_b` are each written at indices 0..63, so both
// need room for at least 64 floats. `shmem_a` is reused to stage the result.
//
// Edge handling: out-of-range elements of the A/B tiles are loaded as zero,
// and only in-range elements of C are stored, so M, N and K need not be
// multiples of 8.
kernel void gemm_f32(
    device const float* A       [[buffer(0)]],
    device const float* B       [[buffer(1)]],
    device       float* C       [[buffer(2)]],
    constant GemmParams& p      [[buffer(3)]],
    uint3 tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]],
    threadgroup float* shmem_a  [[threadgroup(0)]],
    threadgroup float* shmem_b  [[threadgroup(1)]])
{
    // Each threadgroup handles one 8x8 output tile
    const int tile_row = tgpig.y * 8;
    const int tile_col = tgpig.x * 8;

    // Uniform across the threadgroup, so returning before the barriers is safe.
    if (tile_row >= p.M || tile_col >= p.N) return;

    simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8>(0.0f);

    for (int kk = 0; kk < p.K; kk += 8) {
        // Load A tile into shared memory (zero-padded)
        for (int i = tiisg; i < 64; i += 32) {
            int r = i / 8, c = i % 8;
            shmem_a[i] = (tile_row + r < p.M && kk + c < p.K) ? A[(tile_row + r) * p.K + kk + c] : 0.0f;
        }
        // Load B tile (transposed) into shared memory, zero-padded the same way
        for (int i = tiisg; i < 64; i += 32) {
            int r = i / 8, c = i % 8;  // r=K dim, c=N dim
            shmem_b[i] = (tile_col + c < p.N && kk + r < p.K) ? B[(tile_col + c) * p.K + kk + r] : 0.0f;
        }

        // Tiles must be fully written before the SIMD-group matrix loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        simdgroup_float8x8 ma, mb;
        simdgroup_load(ma, shmem_a, 8);
        simdgroup_load(mb, shmem_b, 8);
        simdgroup_multiply_accumulate(acc, ma, mb, acc);

        // Loads must finish before the next iteration overwrites the tiles.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store result (only valid elements)
    // Use shared memory to stage the write
    simdgroup_store(acc, shmem_a, 8);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int i = tiisg; i < 64; i += 32) {
        int r = i / 8, c = i % 8;
        if (tile_row + r < p.M && tile_col + c < p.N) {
            C[(tile_row + r) * p.N + tile_col + c] = shmem_a[i];
        }
    }
}

// ── Argmax ─────────────────────────────────────────────────────────────
// Find index of maximum value in a float array.
// Input: data[n], Output: result[0] = argmax index (as uint32)
// Uses simd reduction for parallel max-finding.

// Parameter block read through buffer(2) by argmax_f32 and argmax_rows_f32.
struct ArgmaxParams {
    int n;   // argmax_f32: length of data; argmax_rows_f32: elements per row (row stride)
};

// Argmax over data[0 .. p.n): writes the index of the largest value to result[0].
//
// Tie-breaking: when the maximum occurs more than once the SMALLEST index is
// returned (first occurrence), independent of the threadgroup size.
// Previously ties were resolved in favour of whichever thread happened to
// own the element, so the answer changed with the launch configuration.
//
// Launch layout: elements are indexed by thread_index_in_threadgroup with
// stride threads_per_threadgroup, so one threadgroup scans the entire array;
// every threadgroup launched writes result[0].
// NOTE(review): simd_id / lane are derived as tid / 32 and tid % 32, so a
// SIMD-group width of 32 is assumed, and tg_max/tg_idx hold at most 32
// SIMD-group partials — confirm the host dispatch respects both.
//
// If p.n <= 0, or no element compares greater than -INFINITY (e.g. all NaN),
// result[0] is 0.
kernel void argmax_f32(
    device const float* data   [[buffer(0)]],
    device uint*        result [[buffer(1)]],
    constant ArgmaxParams& p  [[buffer(2)]],
    uint tid                   [[thread_index_in_threadgroup]],
    uint tg_size               [[threads_per_threadgroup]]
) {
    // Each thread finds local max. Indices are visited in increasing order,
    // so the strict '>' already keeps the first occurrence within a thread.
    float local_max = -INFINITY;
    int   local_idx = 0;
    for (int i = tid; i < p.n; i += tg_size) {
        float v = data[i];
        if (v > local_max) {
            local_max = v;
            local_idx = i;
        }
    }

    // Simd reduction to find the SIMD-group max; equal values keep the lower index.
    for (int offset = 16; offset > 0; offset >>= 1) {
        float other_max = simd_shuffle_down(local_max, offset);
        int   other_idx = simd_shuffle_down(local_idx, offset);
        if (other_max > local_max ||
            (other_max == local_max && other_idx < local_idx)) {
            local_max = other_max;
            local_idx = other_idx;
        }
    }

    // Threadgroup reduction (first thread of each simdgroup)
    threadgroup float tg_max[32];
    threadgroup int   tg_idx[32];
    int simd_id = tid / 32;
    int lane = tid % 32;
    if (lane == 0) {
        tg_max[simd_id] = local_max;
        tg_idx[simd_id] = local_idx;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Final reduction by thread 0, with the same tie-break rule.
    if (tid == 0) {
        float best = tg_max[0];
        int   best_idx = tg_idx[0];
        int n_simd = (tg_size + 31) / 32;
        for (int s = 1; s < n_simd; s++) {
            if (tg_max[s] > best ||
                (tg_max[s] == best && tg_idx[s] < best_idx)) {
                best = tg_max[s];
                best_idx = tg_idx[s];
            }
        }
        result[0] = (uint)best_idx;
    }
}

// Row-wise argmax: for each row r, result[r] = index (within the row) of the
// largest value in data[r * p.n .. r * p.n + p.n).
//
// Tie-breaking: when the row maximum occurs more than once the SMALLEST
// index is returned (first occurrence), independent of the threadgroup size.
// Previously ties were resolved in favour of whichever thread happened to
// own the element, so the answer changed with the launch configuration.
//
// Launch layout: the row is taken from threadgroup_position_in_grid.y; the
// threads of a threadgroup are indexed along x and stride over the row by
// threads_per_threadgroup.x.
// NOTE(review): simd_id / lane are derived as tid / 32 and tid % 32, so a
// SIMD-group width of 32 is assumed, and tg_max/tg_idx hold at most 32
// SIMD-group partials — confirm the host dispatch respects both.
//
// If p.n <= 0, or no element compares greater than -INFINITY (e.g. all NaN),
// result[row] is 0.
kernel void argmax_rows_f32(
    device const float* data   [[buffer(0)]],
    device uint*        result [[buffer(1)]],
    constant ArgmaxParams& p  [[buffer(2)]],
    uint2 tgpig                [[threadgroup_position_in_grid]],
    uint2 tpitg                [[thread_position_in_threadgroup]],
    uint2 tg_size2             [[threads_per_threadgroup]]
) {
    const uint row = tgpig.y;
    const uint tid = tpitg.x;
    const uint tg_size = tg_size2.x;
    device const float* row_data = data + row * uint(p.n);

    // Per-thread scan in increasing index order: strict '>' keeps the first
    // occurrence within a thread.
    float local_max = -INFINITY;
    int local_idx = 0;
    for (int i = int(tid); i < p.n; i += int(tg_size)) {
        float v = row_data[i];
        if (v > local_max) {
            local_max = v;
            local_idx = i;
        }
    }

    // SIMD-group reduction; equal values keep the lower index.
    for (int offset = 16; offset > 0; offset >>= 1) {
        float other_max = simd_shuffle_down(local_max, offset);
        int other_idx = simd_shuffle_down(local_idx, offset);
        if (other_max > local_max ||
            (other_max == local_max && other_idx < local_idx)) {
            local_max = other_max;
            local_idx = other_idx;
        }
    }

    // Lane 0 of every SIMD-group publishes its partial result.
    threadgroup float tg_max[32];
    threadgroup int tg_idx[32];
    int simd_id = int(tid / 32);
    int lane = int(tid % 32);
    if (lane == 0) {
        tg_max[simd_id] = local_max;
        tg_idx[simd_id] = local_idx;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 combines the partials with the same tie-break rule.
    if (tid == 0) {
        float best = tg_max[0];
        int best_idx = tg_idx[0];
        int n_simd = (int(tg_size) + 31) / 32;
        for (int s = 1; s < n_simd; s++) {
            if (tg_max[s] > best ||
                (tg_max[s] == best && tg_idx[s] < best_idx)) {
                best = tg_max[s];
                best_idx = tg_idx[s];
            }
        }
        result[row] = (uint)best_idx;
    }
}

// ── Embedding Lookup ───────────────────────────────────────────────────
// output[i] = table[index * dim + i]
// table: [vocab_size, dim], index: scalar, output: [dim]

// Parameter block read by embedding_lookup_f32 through buffer(3).
struct EmbedParams {
    int dim;   // floats per embedding row; also the number of output elements written
};

// Copies one embedding row: output[i] = table[index[0] * p.dim + i] for i in [0, p.dim).
// One thread per output element; only index[0] is read.
kernel void embedding_lookup_f32(
    device const float* table  [[buffer(0)]],
    device const uint*  index  [[buffer(1)]],
    device float*       output [[buffer(2)]],
    constant EmbedParams& p    [[buffer(3)]],
    uint tid                   [[thread_position_in_grid]]
) {
    if ((int)tid >= p.dim) return;

    const uint token = index[0];
    output[tid] = table[token * p.dim + tid];
}

// ── Split fused QKV ────────────────────────────────────────────────────
// Input:  qkv [tokens, q_dim + 2*kv_dim] (row-major; each row is Q, K, V concatenated)
// Output: q [tokens, q_dim], k [tokens, kv_dim], v [tokens, kv_dim]
// One thread per input element: dispatch tokens × (q_dim + 2*kv_dim) threads.

// Parameter block read by split_qkv_f32 through buffer(4).
struct SplitQkvParams {
    int tokens;   // number of rows in qkv
    int q_dim;    // width of the Q slice of each row
    int kv_dim;   // width of the K slice and of the V slice of each row
};

// Scatters a fused [tokens, q_dim + 2*kv_dim] buffer into separate q, k, v
// buffers. Each thread moves one input element: columns [0, q_dim) go to q,
// [q_dim, q_dim + kv_dim) to k, and the remainder to v.
kernel void split_qkv_f32(
    device const float* qkv     [[buffer(0)]],
    device       float* q       [[buffer(1)]],
    device       float* k       [[buffer(2)]],
    device       float* v       [[buffer(3)]],
    constant SplitQkvParams& p  [[buffer(4)]],
    uint tid                    [[thread_position_in_grid]])
{
    const int flat    = (int)tid;
    const int q_width = p.q_dim;
    const int kv_width = p.kv_dim;
    const int width   = q_width + 2 * kv_width;
    if (flat >= p.tokens * width) return;

    const int token = flat / width;
    const int col   = flat % width;
    const int k_end = q_width + kv_width;   // first column belonging to V
    const float value = qkv[flat];

    if (col >= k_end) {
        v[token * kv_width + (col - q_width - kv_width)] = value;
    } else if (col >= q_width) {
        k[token * kv_width + (col - q_width)] = value;
    } else {
        q[token * q_width + col] = value;
    }
}

// ── LayerNorm (mean + variance) ─────────────────────────────────────────
// Distinct from RMSNorm — Bert / Clip / Whisper use full LayerNorm.
//   out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]

// Parameter block read by layer_norm_f32 through buffer(4).
struct LayerNormParams {
    int dim;     // floats per row
    float eps;   // added to the variance before the square root
};

// Full LayerNorm over one row per threadgroup:
//   output[r, c] = (input[r, c] - mean) / sqrt(var + eps) * gamma[c] + beta[c]
// where mean and var (population variance) are taken over the row.
//
// The row is walked three times with index = thread_index_in_simdgroup and
// stride 32; both reductions are a single simd_sum.
// NOTE(review): this looks written for a 32-thread threadgroup — confirm
// against the host dispatch.
kernel void layer_norm_f32(
    device const float* input   [[buffer(0)]],
    device const float* gamma   [[buffer(1)]],
    device const float* beta    [[buffer(2)]],
    device       float* output  [[buffer(3)]],
    constant LayerNormParams& p [[buffer(4)]],
    uint  tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]])
{
    const int r = (int)tgpig;
    device const float* src = input  + r * p.dim;
    device       float* dst = output + r * p.dim;

    // Mean of the row.
    float acc = 0.0f;
    for (int c = tiisg; c < p.dim; c += 32) {
        acc += src[c];
    }
    acc = simd_sum(acc);
    const float mean = acc / float(p.dim);

    // Sum of squared deviations -> inverse standard deviation.
    float sq_dev = 0.0f;
    for (int c = tiisg; c < p.dim; c += 32) {
        const float centered = src[c] - mean;
        sq_dev += centered * centered;
    }
    sq_dev = simd_sum(sq_dev);
    const float inv_std = 1.0f / sqrt(sq_dev / float(p.dim) + p.eps);

    // Normalise, then apply the affine transform.
    for (int c = tiisg; c < p.dim; c += 32) {
        dst[c] = (src[c] - mean) * inv_std * gamma[c] + beta[c];
    }
}

// ── GELU (erf-based, PyTorch default) ───────────────────────────────────
// Approximate erf via Abramowitz & Stegun 7.1.26 — accurate to ~1.5e-7.

// Parameter block read by gelu_f32 through buffer(2).
struct GeluParams {
    int n;   // element count; threads with index >= n exit without writing
};

// Polynomial approximation of erf(x):
//   erf(|x|) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),
//   t = 1 / (1 + p*|x|),
// evaluated in Horner form and mirrored for negative x (erf is odd).
static inline float erf_approx(float x) {
    const float a1 =  0.254829592f;
    const float a2 = -0.284496736f;
    const float a3 =  1.421413741f;
    const float a4 = -1.453152027f;
    const float a5 =  1.061405429f;

    const float sign = (x < 0.0f) ? -1.0f : 1.0f;
    const float ax = fabs(x);
    const float t = 1.0f / (1.0f + 0.3275911f * ax);

    // Horner evaluation, highest-order coefficient first.
    float poly = a5 * t + a4;
    poly = poly * t + a3;
    poly = poly * t + a2;
    poly = poly * t + a1;

    const float magnitude = 1.0f - poly * t * exp(-ax * ax);
    return sign * magnitude;
}

// Element-wise erf-based GELU:
//   output[i] = 0.5 * x * (1 + erf(x / sqrt(2))),  x = input[i],  i in [0, p.n).
// erf is evaluated with erf_approx. One thread per element.
kernel void gelu_f32(
    device const float* input   [[buffer(0)]],
    device       float* output  [[buffer(1)]],
    constant GeluParams& p      [[buffer(2)]],
    uint tid [[thread_position_in_grid]])
{
    if ((int)tid < p.n) {
        const float v = input[tid];
        const float half_v = 0.5f * v;
        const float erf_term = erf_approx(v * M_SQRT1_2_F);
        output[tid] = half_v * (1.0f + erf_term);
    }
}

// ── Broadcast bias add: data[r, c] += bias[c] ───────────────────────────

// Parameter block read by add_bias_f32 through buffer(2).
struct AddBiasParams {
    int rows;   // number of rows in data
    int cols;   // row width; bias is indexed with (flat index % cols)
};

// In-place broadcast bias add over a [rows, cols] buffer:
//   data[r, c] = data[r, c] + bias[c].
// One thread per element; threads beyond rows * cols do nothing.
kernel void add_bias_f32(
    device       float* data    [[buffer(0)]],
    device const float* bias    [[buffer(1)]],
    constant AddBiasParams& p   [[buffer(2)]],
    uint tid [[thread_position_in_grid]])
{
    const int flat = (int)tid;
    if (flat < p.rows * p.cols) {
        const int col = flat % p.cols;
        data[tid] = data[tid] + bias[col];
    }
}

// ── GEMV: m=1 GEMM specialization ───────────────────────────────────────
// C[1, N] = A[1, K] @ B[N, K]^T
// One threadgroup (1 simdgroup = 32 threads) per output column.
// K-dim reduction via simd_sum.
//
// Note: tried tile_n=4 (A-vector reuse across 4 cols), turned out slower on
// Qwen3-0.6B (30.2 vs 32.5 tok/s). Bottleneck is not A-side DRAM traffic,
// it's B-side reads (vocab×hidden for lm_head dwarfs everything else).

// Matrix-vector product C[j] = sum_k A[k] * B[j, k] for one output column j
// per threadgroup (j = threadgroup_position_in_grid). p.M is not read.
//
// The K dimension is split across lanes with stride 32 and combined with
// simd_sum; lane 0 writes the result. When K is a multiple of 4 the row is
// read as float4 and accumulated with dot().
// NOTE(review): the float4 path reinterprets A and the B row as float4
// pointers, which presumably relies on both base addresses being 16-byte
// aligned — confirm against the host buffer offsets.
kernel void gemv_f32(
    device const float* A       [[buffer(0)]],   // [1, K]
    device const float* B       [[buffer(1)]],   // [N, K]
    device       float* C       [[buffer(2)]],   // [1, N]
    constant GemmParams& p      [[buffer(3)]],
    uint  tgpig [[threadgroup_position_in_grid]],
    uint  tiisg [[thread_index_in_simdgroup]])
{
    const int col = (int)tgpig;
    if (col >= p.N) return;

    const int depth = p.K;
    const int lane  = (int)tiisg;
    device const float* w = B + col * depth;

    float partial = 0.0f;

    if ((depth & 3) != 0) {
        // Scalar fallback for K not divisible by 4.
        for (int k = lane; k < depth; k += 32) {
            partial += A[k] * w[k];
        }
    } else {
        // Four floats per step.
        const int n_vec = depth >> 2;
        device const float4* av = (device const float4*)A;
        device const float4* wv = (device const float4*)w;
        for (int idx = lane; idx < n_vec; idx += 32) {
            const float4 lhs = av[idx];
            const float4 rhs = wv[idx];
            partial += dot(lhs, rhs);
        }
    }

    const float total = simd_sum(partial);

    if (tiisg == 0) {
        C[col] = total;
    }
}

// ── Fused SiLU × Up with gate_up split ──────────────────────────────────
// Input:  gate_up [tokens, 2*im]   (gate = first im, up = second im per row)
// Output: out     [tokens, im]     out[t, i] = silu(gate[t, i]) * up[t, i]

// Parameter block read by silu_mul_split_f32 through buffer(2).
struct SiluMulSplitParams {
    int tokens;   // number of rows in gate_up and in out
    int im;       // output row width; each gate_up row holds 2 * im floats
};

// SwiGLU on a fused gate|up buffer:
//   out[t, i] = silu(gate_up[t, i]) * gate_up[t, im + i],  silu(x) = x / (1 + exp(-x)).
// One thread per output element; threads beyond tokens * im do nothing.
kernel void silu_mul_split_f32(
    device const float* gate_up    [[buffer(0)]],
    device       float* out        [[buffer(1)]],
    constant SiluMulSplitParams& p [[buffer(2)]],
    uint tid                       [[thread_position_in_grid]])
{
    const int flat = (int)tid;
    if (flat >= p.tokens * p.im) return;

    const int token = flat / p.im;
    const int col   = flat % p.im;

    // Gate occupies the first half of each input row, up the second half.
    const int gate_at = token * 2 * p.im + col;
    const int up_at   = token * 2 * p.im + p.im + col;

    const float gate_val = gate_up[gate_at];
    const float up_val   = gate_up[up_at];
    const float activated = gate_val / (1.0f + exp(-gate_val));
    out[flat] = activated * up_val;
}