Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a string usable as a Cargo package name.
///
/// Steps:
/// 1. Lowercase the input.
/// 2. Replace every character that is neither alphanumeric nor `-` with `-`
///    (this includes `_`, spaces and punctuation).
/// 3. Strip leading and trailing `-`.
/// 4. If nothing is left, return `"model"`; if the first character is not a
///    letter (e.g. `"7b-chat"`), prepend `"model-"`.
///
/// Step 4 exists because the result is interpolated into `name = "..."` in the
/// generated `Cargo.toml`, and Cargo rejects package names that are empty or
/// that start with a digit. Names that were already valid are returned exactly
/// as before.
fn sanitize_name(name: &str) -> String {
    // Used both as the empty-name fallback and as the prefix for names that
    // would otherwise start with a digit.
    const FALLBACK: &str = "model";

    let lowered = name.to_lowercase();
    let dashed: String = lowered
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = dashed.trim_matches('-');

    match trimmed.chars().next() {
        // Input was empty or consisted solely of separator characters.
        None => FALLBACK.to_string(),
        // After trimming, the first character is alphanumeric; if it is not a
        // letter it is a digit-like character, which Cargo refuses up front.
        Some(first) if !first.is_alphabetic() => format!("{FALLBACK}-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    r#"//
122// Auto-generated by ForgeLLM Metal codegen.
123// Metal Shading Language compute kernels for transformer inference.
124//
125// Optimized with simdgroup cooperative reductions, shared memory vector
126// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
127// lane, and fast:: math intrinsics for Apple Silicon throughput.
128//
129
130#include <metal_stdlib>
131#include <metal_simdgroup_matrix>
132using namespace metal;
133
134// ── Constants ───────────────────────────────────────────────────────────
135// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
136// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
137// per shared memory load vs 4-row, improving ILP and reducing launches.
138constant constexpr uint SIMDGROUPS_PER_TG = 8;
139constant constexpr uint ROWS_PER_SIMDGROUP = 8;
140constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
141
142// ── matmul_vec ──────────────────────────────────────────────────────────
143// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
144// Uses simdgroup cooperative dot product with shared memory vector caching
145// and float4 vectorized loads. Each simdgroup processes 8 rows for better
146// shared memory reuse (8x vector reuse per load) and instruction-level
147// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
148kernel void matmul_vec(
149    device const float* matrix [[buffer(0)]],
150    device const float* vector [[buffer(1)]],
151    device float* output       [[buffer(2)]],
152    constant uint& rows        [[buffer(3)]],
153    constant uint& cols        [[buffer(4)]],
154    uint tgid [[threadgroup_position_in_grid]],
155    uint tid [[thread_index_in_threadgroup]],
156    uint simd_lane [[thread_index_in_simdgroup]],
157    uint simd_id [[simdgroup_index_in_threadgroup]])
158{
159    // Cooperatively load vector into threadgroup shared memory
160    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
161    for (uint i = tid; i < cols; i += 256) {
162        vec_tile[i] = vector[i];
163    }
164    threadgroup_barrier(mem_flags::mem_threadgroup);
165
166    // Each simdgroup handles 8 consecutive rows
167    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
168    if (row_base >= rows) return;
169
170    uint base0 = row_base * cols;
171    uint base1 = (row_base + 1) * cols;
172    uint base2 = (row_base + 2) * cols;
173    uint base3 = (row_base + 3) * cols;
174    uint base4 = (row_base + 4) * cols;
175    uint base5 = (row_base + 5) * cols;
176    uint base6 = (row_base + 6) * cols;
177    uint base7 = (row_base + 7) * cols;
178
179    // float4 vectorized accumulation across 8 rows
180    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
181    float4 sum4_0 = float4(0.0f);
182    float4 sum4_1 = float4(0.0f);
183    float4 sum4_2 = float4(0.0f);
184    float4 sum4_3 = float4(0.0f);
185    float4 sum4_4 = float4(0.0f);
186    float4 sum4_5 = float4(0.0f);
187    float4 sum4_6 = float4(0.0f);
188    float4 sum4_7 = float4(0.0f);
189
190    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
191        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
192        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
193        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
194        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
195        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
196        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
197        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
198        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
199        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
200    }
201
202    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
203    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
204    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
205    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
206    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
207    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
208    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
209    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
210
211    // Handle remaining elements (cols not divisible by 128)
212    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
213        float vv = vec_tile[j];
214        sum0 += matrix[base0 + j] * vv;
215        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
216        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
217        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
218        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
219        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
220        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
221        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
222    }
223
224    // Simdgroup hardware warp-level reduction
225    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
226    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
227    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
228    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
229
230    // Only first lane writes the results
231    if (simd_lane == 0) {
232        if (row_base     < rows) output[row_base]     = sum0;
233        if (row_base + 1 < rows) output[row_base + 1] = sum1;
234        if (row_base + 2 < rows) output[row_base + 2] = sum2;
235        if (row_base + 3 < rows) output[row_base + 3] = sum3;
236        if (row_base + 4 < rows) output[row_base + 4] = sum4;
237        if (row_base + 5 < rows) output[row_base + 5] = sum5;
238        if (row_base + 6 < rows) output[row_base + 6] = sum6;
239        if (row_base + 7 < rows) output[row_base + 7] = sum7;
240    }
241}
242
243// ── rms_norm ────────────────────────────────────────────────────────────
244// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
245// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
246// via shared memory for minimal synchronization overhead.
247kernel void rms_norm(
248    device const float* input   [[buffer(0)]],
249    device const float* weight  [[buffer(1)]],
250    device float* output        [[buffer(2)]],
251    constant uint& n            [[buffer(3)]],
252    constant float& eps         [[buffer(4)]],
253    uint tid [[thread_index_in_threadgroup]])
254{
255    // Each thread accumulates partial sum-of-squares
256    float sum_sq = 0.0f;
257    for (uint i = tid; i < n; i += 256) {
258        float v = input[i];
259        sum_sq += v * v;
260    }
261
262    // Simdgroup-level reduction (hardware warp sum)
263    sum_sq = simd_sum(sum_sq);
264
265    // Cross-simdgroup reduction via shared memory
266    threadgroup float shared[8];
267    uint simd_id = tid / 32;
268    uint simd_lane = tid % 32;
269    if (simd_lane == 0) {
270        shared[simd_id] = sum_sq;
271    }
272    threadgroup_barrier(mem_flags::mem_threadgroup);
273
274    // First thread computes final inverse RMS
275    if (tid == 0) {
276        float total = 0.0f;
277        for (uint i = 0; i < 8; i++) {
278            total += shared[i];
279        }
280        shared[0] = fast::rsqrt(total / float(n) + eps);
281    }
282    threadgroup_barrier(mem_flags::mem_threadgroup);
283
284    float inv_rms = shared[0];
285
286    // Normalize
287    for (uint i = tid; i < n; i += 256) {
288        output[i] = input[i] * inv_rms * weight[i];
289    }
290}
291
292// ── rope ────────────────────────────────────────────────────────────────
293// Rotary Position Embedding applied in-place.
294// Each thread handles one (head, pair) combination.
295kernel void rope(
296    device float* data        [[buffer(0)]],
297    constant uint& num_heads  [[buffer(1)]],
298    constant uint& head_dim   [[buffer(2)]],
299    constant uint& pos        [[buffer(3)]],
300    constant float& theta     [[buffer(4)]],
301    uint id [[thread_position_in_grid]])
302{
303    uint half_dim = head_dim / 2;
304    uint total_pairs = num_heads * half_dim;
305    if (id >= total_pairs) return;
306
307    uint h = id / half_dim;
308    uint i = id % half_dim;
309    uint off = h * head_dim;
310
311    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
312    float angle = float(pos) * freq;
313    float c = cos(angle);
314    float s = sin(angle);
315
316    float x0 = data[off + 2 * i];
317    float x1 = data[off + 2 * i + 1];
318    data[off + 2 * i]     = x0 * c - x1 * s;
319    data[off + 2 * i + 1] = x0 * s + x1 * c;
320}
321
322// ── softmax ─────────────────────────────────────────────────────────────
323// Numerically stable softmax over a 1-D array.
324// Single-threadgroup kernel with cooperative reduction.
325kernel void softmax(
326    device float* data       [[buffer(0)]],
327    constant uint& n         [[buffer(1)]],
328    uint tid [[thread_index_in_threadgroup]],
329    uint tg_size [[threads_per_threadgroup]])
330{
331    threadgroup float shared_val[256];
332
333    // Pass 1: find max
334    float local_max = -INFINITY;
335    for (uint i = tid; i < n; i += tg_size) {
336        local_max = max(local_max, data[i]);
337    }
338    shared_val[tid] = local_max;
339    threadgroup_barrier(mem_flags::mem_threadgroup);
340
341    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
342        if (tid < stride) {
343            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
344        }
345        threadgroup_barrier(mem_flags::mem_threadgroup);
346    }
347    float max_val = shared_val[0];
348    threadgroup_barrier(mem_flags::mem_threadgroup);
349
350    // Pass 2: exp and sum
351    float local_sum = 0.0f;
352    for (uint i = tid; i < n; i += tg_size) {
353        float e = fast::exp(data[i] - max_val);
354        data[i] = e;
355        local_sum += e;
356    }
357    shared_val[tid] = local_sum;
358    threadgroup_barrier(mem_flags::mem_threadgroup);
359
360    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
361        if (tid < stride) {
362            shared_val[tid] += shared_val[tid + stride];
363        }
364        threadgroup_barrier(mem_flags::mem_threadgroup);
365    }
366    float inv_sum = 1.0f / shared_val[0];
367    threadgroup_barrier(mem_flags::mem_threadgroup);
368
369    // Pass 3: normalize
370    for (uint i = tid; i < n; i += tg_size) {
371        data[i] *= inv_sum;
372    }
373}
374
375// ── silu_mul ────────────────────────────────────────────────────────────
376// Fused SiLU activation * element-wise multiply:
377//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
378kernel void silu_mul(
379    device const float* gate [[buffer(0)]],
380    device const float* up   [[buffer(1)]],
381    device float* output     [[buffer(2)]],
382    constant uint& n         [[buffer(3)]],
383    uint id [[thread_position_in_grid]])
384{
385    if (id >= n) return;
386    float g = gate[id];
387    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
388}
389
390// ── silu_mul_fused ─────────────────────────────────────────────────────
391// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
392//   gate = gate_up[0..n], up = gate_up[n..2*n]
393//   output[i] = silu(gate_up[i]) * gate_up[n + i]
394kernel void silu_mul_fused(
395    device const float* gate_up [[buffer(0)]],
396    device float* output        [[buffer(1)]],
397    constant uint& n            [[buffer(2)]],
398    uint id [[thread_position_in_grid]])
399{
400    if (id >= n) return;
401    float g = gate_up[id];
402    float u = gate_up[n + id];
403    output[id] = (g / (1.0f + fast::exp(-g))) * u;
404}
405
406// ── elementwise_add ─────────────────────────────────────────────────────
407// Residual connection: output[i] = a[i] + b[i]
408kernel void elementwise_add(
409    device const float* a  [[buffer(0)]],
410    device const float* b  [[buffer(1)]],
411    device float* output   [[buffer(2)]],
412    constant uint& n       [[buffer(3)]],
413    uint id [[thread_position_in_grid]])
414{
415    if (id >= n) return;
416    output[id] = a[id] + b[id];
417}
418
419// ── copy_buffer ─────────────────────────────────────────────────────────
420// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
421// transitions. Used for KV cache updates and embedding lookup.
422kernel void copy_buffer(
423    device const float* src [[buffer(0)]],
424    device float* dst       [[buffer(1)]],
425    constant uint& count    [[buffer(2)]],
426    uint id [[thread_position_in_grid]])
427{
428    if (id < count) dst[id] = src[id];
429}
430
431// ── copy_offset ─────────────────────────────────────────────────────────
432// Copy with source offset (in floats). Used for embedding table lookup
433// where we need to copy a specific row from a large table.
434kernel void copy_offset(
435    device const float* src     [[buffer(0)]],
436    device float* dst           [[buffer(1)]],
437    constant uint& src_offset   [[buffer(2)]],  // in floats
438    constant uint& count        [[buffer(3)]],
439    uint id [[thread_position_in_grid]])
440{
441    if (id < count) dst[id] = src[src_offset + id];
442}
443
444// ── add_inplace ─────────────────────────────────────────────────────────
445// In-place residual connection: a[i] += b[i]
446// Avoids a separate blit copy for residual add, reducing encoder overhead.
447kernel void add_inplace(
448    device float* a        [[buffer(0)]],
449    device const float* b  [[buffer(1)]],
450    constant uint& n       [[buffer(2)]],
451    uint id [[thread_position_in_grid]])
452{
453    if (id >= n) return;
454    a[id] += b[id];
455}
456
457// ── matmul_vec_q8 ─────────────────────────────────────────────────────
458// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
459// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
460// Operates directly on quantized weights to halve memory bandwidth vs f32,
461// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
462//
463// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
464// because int8->float conversion doubles register demand.  Fully unrolled
465// inner loop with float4 vector loads from shared memory eliminates loop
466// overhead and enables better instruction scheduling.
467// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
468constant constexpr uint Q8_ROWS_PER_SG = 4;
469constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
470
471// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
472// doubles ALU work, so the same register budget applies).
473constant constexpr uint Q4_ROWS_PER_SG = 4;
474constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
475
476kernel void matmul_vec_q8(
477    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
478    device const float* vector   [[buffer(1)]],  // f32 input
479    device float* output         [[buffer(2)]],
480    constant uint& rows          [[buffer(3)]],
481    constant uint& cols          [[buffer(4)]],  // number of elements per row
482    uint tgid [[threadgroup_position_in_grid]],
483    uint tid [[thread_index_in_threadgroup]],
484    uint simd_lane [[thread_index_in_simdgroup]],
485    uint simd_id [[simdgroup_index_in_threadgroup]])
486{
487    // Load vector into shared memory
488    threadgroup float vec_tile[VEC_TILE_SIZE];
489    for (uint i = tid; i < cols; i += 256) {
490        vec_tile[i] = vector[i];
491    }
492    threadgroup_barrier(mem_flags::mem_threadgroup);
493
494    // Each simdgroup handles 4 consecutive rows (lower register pressure)
495    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
496    if (row_base >= rows) return;
497
498    // Q8_0: each block is 34 bytes for 32 elements
499    uint blocks_per_row = cols / 32;
500    uint row_bytes = blocks_per_row * 34;
501
502    // Pointers to each row's Q8_0 data
503    device const uchar* r0 = matrix + row_base * row_bytes;
504    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
505    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
506    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
507
508    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
509
510    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
511        uint bb = blk * 34;
512        uint vb = blk * 32;
513
514        // Prefetch all 4 scales
515        float sc0 = float(*(device const half*)(r0 + bb));
516        float sc1 = float(*(device const half*)(r1 + bb));
517        float sc2 = float(*(device const half*)(r2 + bb));
518        float sc3 = float(*(device const half*)(r3 + bb));
519
520        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
521        // Q8_0 block layout where the int8 data starts at offset +2 from a
522        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
523        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
524        // reduction in memory transactions. Metal's char16/packed_char16 are
525        // reserved types and packed_*int4 require >=4-byte alignment which
526        // this layout does not provide, so packed_short4 is the widest valid
527        // vectorized load.
528        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
529        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
530        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
531        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
532
533        // Load all 8 float4 vector values for this 32-element block from shared memory
534        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
535        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
536        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
537        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
538        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
539        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
540        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
541        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
542
543        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
544        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
545        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
546            short4 _s = short4(SHORT4); \
547            char2 _a = as_type<char2>(_s.x); \
548            char2 _b = as_type<char2>(_s.y); \
549            char2 _c = as_type<char2>(_s.z); \
550            char2 _d = as_type<char2>(_s.w); \
551            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
552            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
553        }
554
555        float4 f0, f1;
556        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
557
558        // Row 0: 4 short4 loads cover 32 int8 weights
559        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
560        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
561        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
562        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
563
564        // Row 1
565        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
566        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
567        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
568        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
569
570        // Row 2
571        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
572        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
573        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
574        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
575
576        // Row 3
577        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
578        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
579        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
580        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
581
582        #undef Q8_UNPACK8
583
584        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
585    }
586
587    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
588    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
589
590    if (simd_lane == 0) {
591        if (row_base     < rows) output[row_base]     = sum0;
592        if (row_base + 1 < rows) output[row_base + 1] = sum1;
593        if (row_base + 2 < rows) output[row_base + 2] = sum2;
594        if (row_base + 3 < rows) output[row_base + 3] = sum3;
595    }
596}
597
598// ── matmul_vec_q4 ─────────────────────────────────────────────────────
599// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
600// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
601// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
602// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
603//
604// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
605// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
606kernel void matmul_vec_q4(
607    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
608    device const float* vector   [[buffer(1)]],  // f32 input
609    device float* output         [[buffer(2)]],
610    constant uint& rows          [[buffer(3)]],
611    constant uint& cols          [[buffer(4)]],  // number of elements per row
612    uint tgid [[threadgroup_position_in_grid]],
613    uint tid [[thread_index_in_threadgroup]],
614    uint simd_lane [[thread_index_in_simdgroup]],
615    uint simd_id [[simdgroup_index_in_threadgroup]])
616{
617    // Load vector into shared memory
618    threadgroup float vec_tile[VEC_TILE_SIZE];
619    for (uint i = tid; i < cols; i += 256) {
620        vec_tile[i] = vector[i];
621    }
622    threadgroup_barrier(mem_flags::mem_threadgroup);
623
624    // Each simdgroup handles 4 consecutive rows
625    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
626    if (row_base >= rows) return;
627
628    // Q4_0: each block is 18 bytes for 32 elements
629    uint blocks_per_row = cols / 32;
630    uint row_bytes = blocks_per_row * 18;
631
632    // Pointers to each row's Q4_0 data
633    device const uchar* r0 = matrix + row_base * row_bytes;
634    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
635    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
636    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
637
638    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
639
640    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
641        uint bb = blk * 18;
642        uint vb = blk * 32;
643
644        // Prefetch all 4 scales
645        float sc0 = float(*(device const half*)(r0 + bb));
646        float sc1 = float(*(device const half*)(r1 + bb));
647        float sc2 = float(*(device const half*)(r2 + bb));
648        float sc3 = float(*(device const half*)(r3 + bb));
649
650        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
651        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
652        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
653        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
654        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
655
656        // Load 8 float4 vector values for 32 elements from shared memory
657        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
658        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
659        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
660        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
661        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
662        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
663        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
664        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
665        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
666
667        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
668        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
669        float bd0=0, bd1=0, bd2=0, bd3=0;
670        uchar4 b;
671
672        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
673        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
674                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
675                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
676                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
677        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
678                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
679                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
680                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
681        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
682                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
683                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
684                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
685        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
686                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
687                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
688                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
689
690        // Row 1
691        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
692                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
693                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
694                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
695        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
696                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
697                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
698                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
699        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
700                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
701                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
702                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
703        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
704                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
705                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
706                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
707
708        // Row 2
709        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
710                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
711                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
712                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
713        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
714                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
715                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
716                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
717        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
718                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
719                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
720                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
721        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
722                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
723                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
724                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
725
726        // Row 3
727        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
728                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
729                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
730                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
731        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
732                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
733                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
734                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
735        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
736                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
737                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
738                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
739        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
740                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
741                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
742                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
743
744        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
745    }
746
747    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
748    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
749
750    if (simd_lane == 0) {
751        if (row_base     < rows) output[row_base]     = sum0;
752        if (row_base + 1 < rows) output[row_base + 1] = sum1;
753        if (row_base + 2 < rows) output[row_base + 2] = sum2;
754        if (row_base + 3 < rows) output[row_base + 3] = sum3;
755    }
756}
757
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention: one threadgroup per query head.
//
// Pipeline (all reductions are cooperative within the threadgroup):
//   1. scores[s] = (q · k[s]) / sqrt(head_dim), one simdgroup per position
//   2. softmax over scores (simd_max / simd_sum + 8-slot threadgroup combine)
//   3. output[d] = Σ_s scores[s] * v[s][d], one thread per output dimension
//
// The strides below (8 simdgroups, 256 threads) are hard-coded, so the
// kernel must be dispatched with 256 threads per threadgroup.
//
// Context limit: scores live in a fixed 2048-entry threadgroup array (8 KB).
// When seq_len exceeds 2048 the kernel attends over the most recent 2048
// positions only, instead of indexing past the end of that array. For
// seq_len <= 2048 every position is used.
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;
    // Consecutive groups of (num_heads / num_kv_heads) query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Score scratch space. `window` is the number of cached positions that are
    // actually attended; `s_begin` is the first of them, so scores[w] belongs
    // to cache position s_begin + w and every index stays below 2048.
    threadgroup float scores[2048];
    const uint window = min(seq_len, 2048u);
    const uint s_begin = seq_len - window;

    // Step 1: scaled Q·K^T. Each simdgroup takes every 8th position and its
    // 32 lanes split the head_dim dot product.
    const float scale = fast::rsqrt(float(head_dim));
    for (uint w = simd_id; w < window; w += 8) {
        uint k_off = (s_begin + w) * num_kv_heads * head_dim + kv_head * head_dim;
        float qk = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            qk += q[q_off + d] * k_cache[k_off + d];
        }
        qk = simd_sum(qk);
        if (simd_lane == 0) {
            scores[w] = qk * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2a: maximum score (for a numerically stable softmax).
    float local_max = -INFINITY;
    for (uint w = tid; w < window; w += 256) {
        local_max = max(local_max, scores[w]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Step 2b: exponentiate in place and accumulate the normalizer.
    float local_sum = 0.0;
    for (uint w = tid; w < window; w += 256) {
        scores[w] = fast::exp(scores[w] - max_val);
        local_sum += scores[w];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    // Step 2c: normalize so the scores sum to 1.
    for (uint w = tid; w < window; w += 256) {
        scores[w] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: output = scores · V. Each thread owns one output dimension and
    // walks the attended positions four at a time, then finishes the
    // remainder (window % 4) one by one.
    uint window4 = window & ~3u;  // largest multiple of 4 <= window
    uint v_stride = num_kv_heads * head_dim;
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        // Offset of element d of this KV head at the first attended position.
        uint v_base = s_begin * v_stride + kv_head * head_dim + d;
        for (uint w = 0; w < window4; w += 4) {
            float sc0 = scores[w];
            float sc1 = scores[w + 1];
            float sc2 = scores[w + 2];
            float sc3 = scores[w + 3];
            acc += sc0 * v_cache[w * v_stride + v_base]
                 + sc1 * v_cache[(w+1) * v_stride + v_base]
                 + sc2 * v_cache[(w+2) * v_stride + v_base]
                 + sc3 * v_cache[(w+3) * v_stride + v_base];
        }
        for (uint w = window4; w < window; w++) {
            acc += scores[w] * v_cache[w * v_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
874
875// ── Batched prefill kernels ────────────────────────────────────────────
876// These kernels process M input vectors against the same weight matrix
877// in a single dispatch, converting mat-vec into mat-mat for better GPU
878// utilization during prompt prefill.
879
// ── rms_norm_batch ─────────────────────────────────────────────────────
// RMS normalization applied independently to each of M vectors of length n:
//   output[t][i] = input[t][i] * rsqrt(mean(input[t]^2) + eps) * weight[i]
// Grid: M threadgroups, one vector per threadgroup. The 256-element stride
// and the 8 partial-sum slots assume 256 threads per threadgroup.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    if (tgid >= num_tokens) return;

    // This threadgroup's row in the input and output batches.
    device const float* row_in = input + tgid * n;
    device float* row_out = output + tgid * n;

    // Per-thread partial sum of squares, then a reduction across the simdgroup.
    float partial = 0.0f;
    for (uint k = tid; k < n; k += 256) {
        float x = row_in[k];
        partial += x * x;
    }
    partial = simd_sum(partial);

    // Lane 0 of each simdgroup publishes its partial for the final combine.
    threadgroup float shared[8];
    if (tid % 32 == 0) {
        shared[tid / 32] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the 8 partials and broadcasts 1/rms through shared[0].
    if (tid == 0) {
        float total = 0.0f;
        for (uint g = 0; g < 8; g++) {
            total += shared[g];
        }
        shared[0] = fast::rsqrt(total / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float inv_rms = shared[0];

    for (uint k = tid; k < n; k += 256) {
        row_out[k] = row_in[k] * inv_rms * weight[k];
    }
}
929
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place rotary position embedding over a batch of M token vectors, each
// with its own position. One thread rotates one adjacent element pair
// (2i, 2i+1) of one head of one token.
// data layout: [M, num_heads * head_dim], positions: [M]
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread index into (token, head, pair).
    const uint token = id / pairs_per_token;
    const uint within_token = id % pairs_per_token;
    const uint head = within_token / pairs_per_head;
    const uint pair = within_token % pairs_per_head;

    // Pointer to the even element of the pair being rotated.
    device float* p = data + token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    // Rotation angle: position * theta^(-2i / head_dim).
    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle = float(positions[token]) * inv_freq;
    const float cs = cos(angle);
    const float sn = sin(angle);

    const float even = p[0];
    const float odd = p[1];
    p[0] = even * cs - odd * sn;
    p[1] = even * sn + odd * cs;
}
964
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-gate multiply over a batch. Each row of gate_up holds n gate
// values followed by n up values; one thread produces one output element:
//   output[t][i] = silu(gate_up[t][i]) * gate_up[t][n + i]
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint token = id / n;
    const uint col = id % n;

    // Start of this token's [gate | up] row.
    device const float* row = gate_up + token * 2 * n;
    const float gate = row[col];
    const float up = row[n + col];

    // silu(x) = x * sigmoid(x) = x / (1 + e^-x)
    const float activated = gate / (1.0f + fast::exp(-gate));

    // token * n + col is the flat thread index itself.
    output[id] = activated * up;
}
984
// ── add_inplace_batch ──────────────────────────────────────────────────
// Residual add over a flattened batch: a[i] = a[i] + b[i] for i in [0, total).
// One thread per element; threads past `total` do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
996
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous [M, dim] buffer.
// tokens[t] selects the table row copied into output row t; one thread
// copies one float.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    const uint slot = id / dim;   // which batch entry
    const uint col = id % dim;    // which element of the embedding row

    device const float* row = embed + tokens[slot] * dim;
    output[id] = row[col];
}
1014
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Dense f32 matrix-vector product applied to M input vectors that share one
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups; the flat
// threadgroup index decodes to a (token, row group) pair.
// Each simdgroup accumulates 8 consecutive output rows, so this body assumes
// ROWS_PER_SIMDGROUP is 8 and 256 threads per threadgroup.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Flat threadgroup index -> (token, row group within that token).
    const uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    const uint token = tgid / groups_per_token;
    const uint row_group = tgid % groups_per_token;
    if (token >= num_tokens) return;

    // Stage the token's input vector in threadgroup memory so all simdgroups
    // read it from there rather than from device memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* x = inputs + token * cols;
    for (uint c = tid; c < cols; c += 256) {
        vec_tile[c] = x[c];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint row_base = row_group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Row 0 is valid past this point; rows 1..7 may fall off the end of the
    // matrix in the last row group, so each gets its own guard.
    const bool live1 = row_base + 1 < rows;
    const bool live2 = row_base + 2 < rows;
    const bool live3 = row_base + 3 < rows;
    const bool live4 = row_base + 4 < rows;
    const bool live5 = row_base + 5 < rows;
    const bool live6 = row_base + 6 < rows;
    const bool live7 = row_base + 7 < rows;

    // Element offset of each row's first weight.
    const uint off0 = row_base * cols;
    const uint off1 = (row_base + 1) * cols;
    const uint off2 = (row_base + 2) * cols;
    const uint off3 = (row_base + 3) * cols;
    const uint off4 = (row_base + 4) * cols;
    const uint off5 = (row_base + 5) * cols;
    const uint off6 = (row_base + 6) * cols;
    const uint off7 = (row_base + 7) * cols;

    // Vectorized part: 32 lanes * float4 = 128 columns per step, covering the
    // largest multiple of 128 that fits in cols.
    const uint vec_end = cols & ~127u;
    float4 acc0 = float4(0.0f);
    float4 acc1 = float4(0.0f);
    float4 acc2 = float4(0.0f);
    float4 acc3 = float4(0.0f);
    float4 acc4 = float4(0.0f);
    float4 acc5 = float4(0.0f);
    float4 acc6 = float4(0.0f);
    float4 acc7 = float4(0.0f);

    #define MV_W4(OFF) (*reinterpret_cast<device const float4*>(matrix + (OFF) + j))
    for (uint j = simd_lane * 4; j < vec_end; j += 128) {
        const float4 xv = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        acc0 += MV_W4(off0) * xv;
        if (live1) acc1 += MV_W4(off1) * xv;
        if (live2) acc2 += MV_W4(off2) * xv;
        if (live3) acc3 += MV_W4(off3) * xv;
        if (live4) acc4 += MV_W4(off4) * xv;
        if (live5) acc5 += MV_W4(off5) * xv;
        if (live6) acc6 += MV_W4(off6) * xv;
        if (live7) acc7 += MV_W4(off7) * xv;
    }
    #undef MV_W4

    // Collapse each float4 accumulator into this lane's scalar partial.
    float sum0 = acc0.x + acc0.y + acc0.z + acc0.w;
    float sum1 = acc1.x + acc1.y + acc1.z + acc1.w;
    float sum2 = acc2.x + acc2.y + acc2.z + acc2.w;
    float sum3 = acc3.x + acc3.y + acc3.z + acc3.w;
    float sum4 = acc4.x + acc4.y + acc4.z + acc4.w;
    float sum5 = acc5.x + acc5.y + acc5.z + acc5.w;
    float sum6 = acc6.x + acc6.y + acc6.z + acc6.w;
    float sum7 = acc7.x + acc7.y + acc7.z + acc7.w;

    // Scalar tail: the remaining cols % 128 columns, one per lane per step.
    for (uint j = vec_end + simd_lane; j < cols; j += 32) {
        const float xs = vec_tile[j];
        sum0 += matrix[off0 + j] * xs;
        if (live1) sum1 += matrix[off1 + j] * xs;
        if (live2) sum2 += matrix[off2 + j] * xs;
        if (live3) sum3 += matrix[off3 + j] * xs;
        if (live4) sum4 += matrix[off4 + j] * xs;
        if (live5) sum5 += matrix[off5 + j] * xs;
        if (live6) sum6 += matrix[off6 + j] * xs;
        if (live7) sum7 += matrix[off7 + j] * xs;
    }

    // Combine the 32 lane partials of each row.
    sum0 = simd_sum(sum0);
    sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2);
    sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4);
    sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6);
    sum7 = simd_sum(sum7);

    // Lane 0 writes the results for the rows that exist.
    if (simd_lane == 0) {
        device float* y = outputs + token * rows;
        y[row_base] = sum0;
        if (live1) y[row_base + 1] = sum1;
        if (live2) y[row_base + 2] = sum2;
        if (live3) y[row_base + 3] = sum3;
        if (live4) y[row_base + 4] = sum4;
        if (live5) y[row_base + 5] = sum5;
        if (live6) y[row_base + 6] = sum6;
        if (live7) y[row_base + 7] = sum7;
    }
}
1116
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups; the flat threadgroup
// index decodes to a (token, row group) pair.
//
// Weight layout as read here: each row is cols/32 blocks of 34 bytes — a
// 2-byte half scale followed by 32 signed 8-bit weights. Columns beyond the
// last full block of 32 are not processed.
//
// Each simdgroup produces 4 consecutive rows (this body assumes
// Q8_ROWS_PER_SG is 4); its 32 lanes split the blocks of a row and are
// combined with simd_sum. The load stride assumes 256 threads per threadgroup.
//
// Rows past the end of the matrix (when rows is not a multiple of 4) are
// clamped to the last valid row so that no lane reads outside the weight
// buffer; their results are discarded by the guarded stores at the end.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // row_base < rows here, so rows >= 1 and last_row is a valid index.
    // Clamping keeps the loads for rows 1..3 inside the matrix when the
    // final group of four is only partially populated.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of this block within a row
        uint vb = blk * 32;  // first input column covered by this block

        // Per-block scale of each row (half stored at the start of the block).
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 input values matching this block, as eight float4s.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 bytes) into eight signed 8-bit weights,
        // returned as two float4s in memory order.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        // Unscaled integer-weight dot products for this block, one per row.
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Only rows that really exist are written; clamped duplicates are dropped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1232
1233// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
1234// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
1235// Each threadgroup covers 32 rows and TOKENS_PER_TG_Q8 consecutive tokens, so
1236// the Q8_0 weight blocks are fetched once from device memory and reused for
1237// every token in the tile (1/TOKENS_PER_TG_Q8 the weight bandwidth of the
1238// per-token dispatch).
1239//
1240// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
1241// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
1242// with simd_sum.  Token vectors are read directly from device memory inside
1243// the block loop (not cached in shared memory) so intermediate_size up to
1244// 8192 fits without spilling threadgroup memory.
1245constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1246
1247kernel void matmul_q8_gemm_batch(
1248    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1249    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1250    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1251    constant uint& num_tokens    [[buffer(3)]],  // M
1252    constant uint& rows          [[buffer(4)]],
1253    constant uint& cols          [[buffer(5)]],
1254    uint2 tgid [[threadgroup_position_in_grid]],
1255    uint tid [[thread_index_in_threadgroup]],
1256    uint simd_lane [[thread_index_in_simdgroup]],
1257    uint simd_id [[simdgroup_index_in_threadgroup]])
1258{
1259    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1260    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1261    if (row_base >= rows || tok_base >= num_tokens) return;
1262
1263    // How many tokens in this tile are valid?
1264    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1265
1266    uint blocks_per_row = cols / 32;
1267    uint row_bytes = blocks_per_row * 34;
1268
1269    device const uchar* r0 = matrix + row_base * row_bytes;
1270    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1271    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1272    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1273
1274    // Accumulators: 4 tokens × 4 rows per simdgroup.
1275    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1276    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1277    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1278    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1279
1280    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1281        uint bb = blk * 34;
1282        uint vb = blk * 32;
1283
1284        // ── Load weight data ONCE per block (reused across all tokens) ──
1285        float sc0 = float(*(device const half*)(r0 + bb));
1286        float sc1 = float(*(device const half*)(r1 + bb));
1287        float sc2 = float(*(device const half*)(r2 + bb));
1288        float sc3 = float(*(device const half*)(r3 + bb));
1289
1290        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1291        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1292        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1293        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1294
1295        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1296            short4 _s = short4(SHORT4); \
1297            char2 _a = as_type<char2>(_s.x); \
1298            char2 _b = as_type<char2>(_s.y); \
1299            char2 _c = as_type<char2>(_s.z); \
1300            char2 _d = as_type<char2>(_s.w); \
1301            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1302            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1303        }
1304
1305        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1306        // registers for the duration of the block and are dotted against
1307        // every token's vector tile.
1308        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1309        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1310        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1311        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1312
1313        Q8_UNPACK8(d0[0], w0_0, w0_1);
1314        Q8_UNPACK8(d0[1], w0_2, w0_3);
1315        Q8_UNPACK8(d0[2], w0_4, w0_5);
1316        Q8_UNPACK8(d0[3], w0_6, w0_7);
1317
1318        Q8_UNPACK8(d1[0], w1_0, w1_1);
1319        Q8_UNPACK8(d1[1], w1_2, w1_3);
1320        Q8_UNPACK8(d1[2], w1_4, w1_5);
1321        Q8_UNPACK8(d1[3], w1_6, w1_7);
1322
1323        Q8_UNPACK8(d2[0], w2_0, w2_1);
1324        Q8_UNPACK8(d2[1], w2_2, w2_3);
1325        Q8_UNPACK8(d2[2], w2_4, w2_5);
1326        Q8_UNPACK8(d2[3], w2_6, w2_7);
1327
1328        Q8_UNPACK8(d3[0], w3_0, w3_1);
1329        Q8_UNPACK8(d3[1], w3_2, w3_3);
1330        Q8_UNPACK8(d3[2], w3_4, w3_5);
1331        Q8_UNPACK8(d3[3], w3_6, w3_7);
1332
1333        #undef Q8_UNPACK8
1334
1335        // ── For each token, read vector and accumulate against shared weights ──
1336        // Token 0 (always valid: tok_count >= 1).
1337        {
1338            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1339            float4 v0 = *(device const float4*)(a0);
1340            float4 v1 = *(device const float4*)(a0 + 4);
1341            float4 v2 = *(device const float4*)(a0 + 8);
1342            float4 v3 = *(device const float4*)(a0 + 12);
1343            float4 v4 = *(device const float4*)(a0 + 16);
1344            float4 v5 = *(device const float4*)(a0 + 20);
1345            float4 v6 = *(device const float4*)(a0 + 24);
1346            float4 v7 = *(device const float4*)(a0 + 28);
1347            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1348                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1349            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1350                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1351            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1352                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1353            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1354                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1355            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1356        }
1357        // Token 1
1358        if (tok_count > 1) {
1359            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1360            float4 v0 = *(device const float4*)(a1);
1361            float4 v1 = *(device const float4*)(a1 + 4);
1362            float4 v2 = *(device const float4*)(a1 + 8);
1363            float4 v3 = *(device const float4*)(a1 + 12);
1364            float4 v4 = *(device const float4*)(a1 + 16);
1365            float4 v5 = *(device const float4*)(a1 + 20);
1366            float4 v6 = *(device const float4*)(a1 + 24);
1367            float4 v7 = *(device const float4*)(a1 + 28);
1368            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1369                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1370            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1371                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1372            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1373                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1374            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1375                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1376            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1377        }
1378        // Token 2
1379        if (tok_count > 2) {
1380            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1381            float4 v0 = *(device const float4*)(a2);
1382            float4 v1 = *(device const float4*)(a2 + 4);
1383            float4 v2 = *(device const float4*)(a2 + 8);
1384            float4 v3 = *(device const float4*)(a2 + 12);
1385            float4 v4 = *(device const float4*)(a2 + 16);
1386            float4 v5 = *(device const float4*)(a2 + 20);
1387            float4 v6 = *(device const float4*)(a2 + 24);
1388            float4 v7 = *(device const float4*)(a2 + 28);
1389            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1390                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1391            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1392                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1393            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1394                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1395            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1396                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1397            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1398        }
1399        // Token 3
1400        if (tok_count > 3) {
1401            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1402            float4 v0 = *(device const float4*)(a3);
1403            float4 v1 = *(device const float4*)(a3 + 4);
1404            float4 v2 = *(device const float4*)(a3 + 8);
1405            float4 v3 = *(device const float4*)(a3 + 12);
1406            float4 v4 = *(device const float4*)(a3 + 16);
1407            float4 v5 = *(device const float4*)(a3 + 20);
1408            float4 v6 = *(device const float4*)(a3 + 24);
1409            float4 v7 = *(device const float4*)(a3 + 28);
1410            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1411                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1412            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1413                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1414            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1415                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1416            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1417                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1418            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1419        }
1420    }
1421
1422    // simdgroup reduction
1423    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1424    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1425    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1426    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1427
1428    if (simd_lane == 0) {
1429        device float* o0 = outputs + (tok_base + 0) * rows;
1430        if (row_base     < rows) o0[row_base]     = s00;
1431        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1432        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1433        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1434
1435        if (tok_count > 1) {
1436            device float* o1 = outputs + (tok_base + 1) * rows;
1437            if (row_base     < rows) o1[row_base]     = s10;
1438            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1439            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1440            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1441        }
1442        if (tok_count > 2) {
1443            device float* o2 = outputs + (tok_base + 2) * rows;
1444            if (row_base     < rows) o2[row_base]     = s20;
1445            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1446            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1447            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1448        }
1449        if (tok_count > 3) {
1450            device float* o3 = outputs + (tok_base + 3) * rows;
1451            if (row_base     < rows) o3[row_base]     = s30;
1452            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1453            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1454            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1455        }
1456    }
1457}
1458
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
//
// Computes outputs[tok * rows + row] = sum_k inputs[tok * cols + k] * W[row, k],
// where W[row, k] is dequantized on the fly from 34-byte Q8_0 blocks
// (a 2-byte half scale followed by 32 int8 quants).
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups per TG (128 threads), each computing a single 8×8 output
// sub-tile via one simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes
// are cooperatively dequantized into threadgroup memory once per block and
// reused by all simdgroups in the tile.
//
// Assumptions (verified in the dispatch helper, falls back otherwise):
//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
//   * num_tokens may be any value; a partial token tile at the boundary is
//     handled via a scratch copy path.
//   * tid % 32 is used as the lane index, i.e. a SIMD width of 32 is assumed.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Uniform across the threadgroup, so returning before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB for the pair).
    // The partial-tile scratch declared near the end adds another 1 KB.
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8;  // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread owns 4 consecutive tile elements [tid*4, tid*4 + 4).  Because
    // 4 divides 32 they always lie in a single tile row, so the row index, the
    // Q8_0 block pointer, the scale and the token bounds test are per-thread
    // invariants and are computed once instead of once per element.
    uint elem_base = tid * 4;
    uint tile_r    = elem_base / 32;   // weight row / token index inside the tile
    uint tile_k0   = elem_base % 32;   // first K offset handled by this thread

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        {
            device const uchar* rp = matrix + (row_base + tile_r) * row_bytes + blk * 34;
            float sc = float(*(device const half*)rp);                      // block scale
            device const int8_t* q = (device const int8_t*)(rp + 2) + tile_k0;
            threadgroup float* wdst = w_tile + elem_base;
            for (uint ii = 0; ii < 4; ii++) {
                wdst[ii] = float(int(q[ii])) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Tokens past num_tokens are zero-filled so they contribute nothing.
        {
            uint tok = tok_base + tile_r;
            threadgroup float* tdst = t_tile + elem_base;
            if (tok < num_tokens) {
                device const float* src = inputs + tok * cols + blk * 32 + tile_k0;
                for (uint ii = 0; ii < 4; ii++) {
                    tdst[ii] = src[ii];
                }
            } else {
                for (uint ii = 0; ii < 4; ii++) {
                    tdst[ii] = 0.0f;
                }
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B,
                w_tile + sg_row_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Keep the next iteration from overwriting tiles that are still being read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage the accumulator in per-simdgroup scratch,
        // then copy only the valid token rows.  The copy is spread over the
        // lanes of the simdgroup instead of being serialized on lane 0.
        threadgroup float scratch[4 * 64];
        threadgroup float* sg_scratch = scratch + simd_id * 64;
        simdgroup_store(C, sg_scratch, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        for (uint e = lane; e < valid * 8; e += 32) {
            uint m = e / 8;  // token row inside the sub-tile
            uint j = e % 8;  // output row inside the sub-tile
            outputs[(out_tok + m) * rows + out_row + j] = sg_scratch[e];
        }
    }
    // Sub-tiles with out_tok >= num_tokens hold only padding and are not stored.
}
1587
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
//
//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//     output sub-tiles (tok, row):
//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration — better FLOP/load ratio than the 16×16 single-accumulator
// kernel — and halves the number of threadgroups vs the 16×16 tile.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
//   * tid % 32 is used as the lane index, i.e. a SIMD width of 32 is assumed.
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 weights, 34 bytes per 32-wide block
    device const float* inputs   [[buffer(1)]],  // indexed as inputs[tok * cols + k]
    device float* outputs        [[buffer(2)]],  // indexed as outputs[tok * rows + row]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so returning before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB for the pair.
    // The partial-tile scratch declared near the end adds another 4 KB.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread owns 4 consecutive tile elements [tid*4, tid*4 + 4).  Because
    // 4 divides 32 they always lie in a single tile row, so the row index, the
    // Q8_0 block pointer, the scale and the token bounds test are per-thread
    // invariants and are computed once instead of once per element.
    uint elem_base = tid * 4;
    uint tile_r    = elem_base / 32;   // weight row / token index inside the tile
    uint tile_k0   = elem_base % 32;   // first K offset handled by this thread

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        {
            device const uchar* rp = matrix + (row_base + tile_r) * row_bytes + blk * 34;
            float sc = float(*(device const half*)rp);                      // block scale
            device const int8_t* q = (device const int8_t*)(rp + 2) + tile_k0;
            threadgroup float* wdst = w_tile + elem_base;
            for (uint ii = 0; ii < 4; ii++) {
                wdst[ii] = float(int(q[ii])) * sc;
            }
        }

        // Cooperative token tile load; tokens past num_tokens are zero-filled.
        {
            uint tok = tok_base + tile_r;
            threadgroup float* tdst = t_tile + elem_base;
            if (tok < num_tokens) {
                device const float* src = inputs + tok * cols + blk * 32 + tile_k0;
                for (uint ii = 0; ii < 4; ii++) {
                    tdst[ii] = src[ii];
                }
            } else {
                for (uint ii = 0; ii < 4; ii++) {
                    tdst[ii] = 0.0f;
                }
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            // Weight fragments are loaded transposed so that B is K×R.
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration from overwriting tiles that are still being read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage both accumulators in per-simdgroup scratch
        // and copy only the valid token rows, spread over the simdgroup lanes
        // instead of being serialized on lane 0.
        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
        threadgroup float* sg_scratch = scratch + simd_id * 128;
        simdgroup_store(C_a, sg_scratch, 8);
        simdgroup_store(C_b, sg_scratch + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        for (uint e = lane; e < valid * 8; e += 32) {
            uint m = e / 8;  // token row inside the sub-tile
            uint j = e % 8;  // output row inside the sub-tile
            device float* dst = outputs + (out_tok + m) * rows;
            dst[out_row_a + j] = sg_scratch[e];
            dst[out_row_b + j] = sg_scratch[64 + e];
        }
    }
    // Sub-tiles with out_tok >= num_tokens hold only padding and are not stored.
}
1728
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory, halving the tile footprint (4 KB for the two tiles vs 8 KB in
// matmul_q8_mma32) with the aim of raising concurrent-threadgroup occupancy.
// NOTE(review): the partial-tile scratch declared near the end of the kernel
// is still 8*2*64 floats (4 KB), so the total threadgroup storage declared
// here is larger than the tiles alone — confirm the occupancy gain by
// measurement.
//
// Numerics: each Q8_0 weight is int8 × half scale (the scale is read from
// the block as `half`); the product is formed in float and then rounded to
// `half` for the tile.  Token activations are narrowed f32 → f16 on load,
// while `simdgroup_multiply_accumulate` keeps the accumulators in `float`.
// NOTE(review): activations outside the finite half range would not survive
// the narrowing — presumably inputs are bounded; verify against callers.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
// accumulators each (256 threads; tid % 32 is used as the lane index).
// Same cols % 32 == 0 / rows % 32 == 0 tile-alignment requirements as
// matmul_q8_mma32, since the tile constants and indexing are shared.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 weights, 34 bytes per 32-wide block
    device const float* inputs   [[buffer(1)]],  // indexed as inputs[tok * cols + k]
    device float* outputs        [[buffer(2)]],  // indexed as outputs[tok * rows + row]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so returning before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB for the pair.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread owns 4 consecutive tile elements [tid*4, tid*4 + 4).  Because
    // 4 divides 32 they always lie in a single tile row, so the row index, the
    // Q8_0 block pointer, the scale and the token bounds test are per-thread
    // invariants and are computed once instead of once per element.
    uint elem_base = tid * 4;
    uint tile_r    = elem_base / 32;   // weight row / token index inside the tile
    uint tile_k0   = elem_base % 32;   // first K offset handled by this thread

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16 (product formed in float first).
        {
            device const uchar* rp = matrix + (row_base + tile_r) * row_bytes + blk * 34;
            float sc = float(*(device const half*)rp);                      // block scale
            device const int8_t* q = (device const int8_t*)(rp + 2) + tile_k0;
            threadgroup half* wdst = w_tile + elem_base;
            for (uint ii = 0; ii < 4; ii++) {
                wdst[ii] = half(float(int(q[ii])) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); tokens past
        // num_tokens are zero-filled.
        {
            uint tok = tok_base + tile_r;
            threadgroup half* tdst = t_tile + elem_base;
            if (tok < num_tokens) {
                device const float* src = inputs + tok * cols + blk * 32 + tile_k0;
                for (uint ii = 0; ii < 4; ii++) {
                    tdst[ii] = half(src[ii]);
                }
            } else {
                for (uint ii = 0; ii < 4; ii++) {
                    tdst[ii] = half(0);
                }
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each; A is reused for both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            // Weight fragments are loaded transposed so that B is K×R.
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration from overwriting tiles that are still being read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: all 8 token rows of this sub-tile are in-bounds.
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage both accumulators in per-simdgroup scratch
        // and copy only the valid token rows, spread over the simdgroup lanes
        // instead of being serialized on lane 0.
        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
        threadgroup float* sg_scratch = scratch + simd_id * 128;
        simdgroup_store(C_a, sg_scratch, 8);
        simdgroup_store(C_b, sg_scratch + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        for (uint e = lane; e < valid * 8; e += 32) {
            uint m = e / 8;  // token row inside the sub-tile
            uint j = e % 8;  // output row inside the sub-tile
            device float* dst = outputs + (out_tok + m) * rows;
            dst[out_row_a + j] = sg_scratch[e];
            dst[out_row_b + j] = sg_scratch[64 + e];
        }
    }
    // Sub-tiles with out_tok >= num_tokens hold only padding and are not stored.
}
1857
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's.  That is four MMAs per four fragment loads, versus two MMAs per
// three loads in the 2-accumulator kernel, and it halves the thread count
// per threadgroup (128 threads), which often improves occupancy on Apple
// GPUs where the concurrent-thread budget is the tighter limit than
// shared-memory size.
//
// Same cols % 32 == 0 / rows % 32 == 0 tile-alignment requirements as
// matmul_q8_mma32; tid % 32 is used as the lane index (SIMD width 32).
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 weights, 34 bytes per 32-wide block
    device const float* inputs   [[buffer(1)]],  // indexed as inputs[tok * cols + k]
    device float* outputs        [[buffer(2)]],  // indexed as outputs[tok * rows + row]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so returning before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 2 KB each, 4 KB for the pair.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread owns 8 consecutive tile elements [tid*8, tid*8 + 8).  Because
    // 8 divides 32 they always lie in a single tile row, so the row index, the
    // Q8_0 block pointer, the scale and the token bounds test are per-thread
    // invariants and are computed once instead of once per element.
    uint elem_base = tid * 8;
    uint tile_r    = elem_base / 32;   // weight row / token index inside the tile
    uint tile_k0   = elem_base % 32;   // first K offset handled by this thread

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        {
            device const uchar* rp = matrix + (row_base + tile_r) * row_bytes + blk * 34;
            float sc = float(*(device const half*)rp);                      // block scale
            device const int8_t* q = (device const int8_t*)(rp + 2) + tile_k0;
            threadgroup half* wdst = w_tile + elem_base;
            for (uint ii = 0; ii < 8; ii++) {
                wdst[ii] = half(float(int(q[ii])) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); tokens past
        // num_tokens are zero-filled.
        {
            uint tok = tok_base + tile_r;
            threadgroup half* tdst = t_tile + elem_base;
            if (tok < num_tokens) {
                device const float* src = inputs + tok * cols + blk * 32 + tile_k0;
                for (uint ii = 0; ii < 8; ii++) {
                    tdst[ii] = half(src[ii]);
                }
            } else {
                for (uint ii = 0; ii < 8; ii++) {
                    tdst[ii] = half(0);
                }
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            // Weight fragments are loaded transposed so that B is K×R.
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Keep the next iteration from overwriting tiles that are still being read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  Full-tile fast path assumes full 16×16 valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch.  Every token row is
        // bounds-checked individually, so this also covers quadrants that lie
        // entirely past num_tokens (nothing is written for them).  The copy is
        // spread over the simdgroup lanes instead of being serialized on lane 0.
        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
        threadgroup float* sg_scratch = scratch + simd_id * 256;
        simdgroup_store(C_00, sg_scratch +   0, 8);
        simdgroup_store(C_01, sg_scratch +  64, 8);
        simdgroup_store(C_10, sg_scratch + 128, 8);
        simdgroup_store(C_11, sg_scratch + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        for (uint e = lane; e < 64; e += 32) {
            uint m = e / 8;  // token row inside an 8×8 sub-tile
            uint j = e % 8;  // output row inside an 8×8 sub-tile
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst = outputs + t_top * rows;
                dst[out_row_lo + j] = sg_scratch[  0 + e];
                dst[out_row_hi + j] = sg_scratch[ 64 + e];
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst = outputs + t_bot * rows;
                dst[out_row_lo + j] = sg_scratch[128 + e];
                dst[out_row_hi + j] = sg_scratch[192 + e];
            }
        }
    }
}
2012
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)).  Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Data layout, as indexed below:
//   matrix  : per row, cols/32 blocks of 34 bytes each — a half scale at
//             byte 0 followed by 32 int8 values starting at byte 2.
//   inputs  : row-major [num_tokens, cols] floats.
//   outputs : row-major [num_tokens, rows] floats.
//
// Tiling: threadgroup (tgid.x, tgid.y) owns MMA32_ROW_TILE output rows and
// MMA32_TOK_TILE tokens.  Each simdgroup owns a 16x16 sub-tile, computed as
// a 2x2 grid of 8x8 fragments (C_00, C_01, C_10, C_11).
//
// NOTE(review): each thread stages 8 elements per tile (idx = tid * 8 + ii)
// and scratch is sized for 4 simdgroups, so this presumably expects 128
// threads and 32x32 tiles — confirm against the host dispatch and the
// MMA32_* definitions.
// NOTE(review): neither the weight staging nor the output store checks
// `row < rows`; only tokens are bounds-checked.  Presumably the host only
// selects this kernel when rows is a multiple of the row tile — confirm.
// `cols / 32` truncates, so cols is assumed to be a multiple of 32.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // First output row / first token handled by this threadgroup.
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // The condition depends only on tgid, so the whole threadgroup exits
    // together and no thread is left waiting at a later barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Staging tiles for one 32-wide K block: dequantized weights and
    // activations, both narrowed to half.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Position of this simdgroup's 16x16 sub-tile inside the threadgroup tile.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // Half accumulators: first index = token half (top/bot), second = row
    // half (lo/hi).
    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Walk the K dimension one 32-element quantization block at a time.
    for (uint blk = 0; blk < blocks_per_row; blk++) {
        {
            // Stage weights: dequantize (int8 * block scale) into w_tile.
            // A thread's 8 elements share one row, since 8 divides 32.
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        {
            // Stage activations: narrow f32 inputs to half; tokens past the
            // end of the batch contribute zeros.
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        // Both tiles must be fully written before any simdgroup loads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four 8-wide K slices per 32-wide block.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            // A fragments: token rows x K slice, loaded as stored.
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            // B fragments: weight rows x K slice, loaded with the transpose
            // flag set so the product is tokens x rows.
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Nobody may overwrite the tiles for the next block until every
        // simdgroup has finished reading this one.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;

    // 256 halfs per simdgroup: four 8x8 fragments at offsets 0/64/128/192.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    // Each simdgroup reads back only its own scratch region, so a
    // simdgroup-scope barrier is used here rather than a threadgroup one.
    simdgroup_barrier(mem_flags::mem_threadgroup);
    uint lane = tid % 32;
    // One lane per simdgroup copies the 16x16 result out, widening to float
    // and skipping tokens beyond the batch.
    if (lane == 0) {
        for (uint m = 0; m < 8; m++) {
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base +   0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base +  64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst0[j] = float(src0[j]);
                    dst1[j] = float(src1[j]);
                }
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst2[j] = float(src2[j]);
                    dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2150
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add: every token row of a row-major [num_tokens, rows]
// buffer receives the same [rows] bias vector.
// Used for Qwen2 QKV bias after the fused qkv matmul.
//     out[token * rows + i] += bias[i]
// One thread per output element; `id` is the flat element index.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows], updated in place
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Threads in the grid padding past the last element do nothing.
    if (id < num_tokens * rows) {
        // Column within the row selects the bias entry.
        out[id] += bias[id % rows];
    }
}
2167
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup handles one (token, row-group) pair: it stages that
// token's input vector in threadgroup memory, then every simdgroup computes
// four consecutive output rows, with its 32 lanes striding over the
// quantization blocks of those rows.
//
// Q4_0 layout as decoded below: 18 bytes per 32-column block — a half scale
// followed by 16 bytes.  Byte j holds element j in its low nibble and
// element j + 16 in its high nibble; each nibble is stored with a +8 bias.
//
// The last row group may be partial (rows not a multiple of 4).  Row
// pointers for the missing rows are clamped to the last valid row so the
// kernel never reads past the end of `matrix`; their sums are dropped by
// the guarded stores at the end.
//
// NOTE(review): the staging loop strides by 256, so this presumably expects
// 256 threads per threadgroup, and `cols` must fit in vec_tile — confirm
// against the host dispatch.  `cols / 32` truncates, so cols is assumed to
// be a multiple of 32.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decompose the flat threadgroup index into (token, row group).
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform across the threadgroup, so safe ahead of the barrier below.
    if (token >= num_tokens) return;

    // Stage this token's input vector once; every row reuses it.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so simdgroups with no rows to compute can leave.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // row_base < rows here, so rows >= 1 and last_row cannot underflow.
    // Clamping keeps r1..r3 inside the matrix for a partial final group.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Each lane takes every 32nd block of the four rows.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;   // byte offset of this block within a row
        uint vb = blk * 32;   // element offset of this block within the input

        // Per-block scales.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // 16 payload bytes per block, read as four packed_uchar4 (the
        // payload starts at a 2-byte offset, hence the packed type).
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // The block's 32 inputs: v0..v3 pair with low nibbles (elements
        // 0..15), v4..v7 with high nibbles (elements 16..31).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Unscaled block dot products, fully unrolled per row.
        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the 32 per-lane partial sums of each row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes the results; rows past the end are skipped, which also
    // discards the sums computed from clamped row pointers.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2268
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter one tensor (K or V) from the strided batch QKV buffer into its
// KV cache.
// src layout: [M, src_stride]; the K/V slice of each row starts at
//             src_offset floats into the row.
// dst layout: contiguous [max_seq, kv_dim] cache; token t of the batch
//             lands at cache row base_pos + t.
// One thread per copied float.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer
    device float* dst        [[buffer(1)]],  // KV cache
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    // Ignore grid padding beyond the M * kv_dim elements to copy.
    if (id >= M * kv_dim) return;

    // Flat index -> (batch row, column within the KV vector).
    const uint row = id / kv_dim;
    const uint col = id % kv_dim;

    dst[(base_pos + row) * kv_dim + col] = src[row * src_stride + src_offset + col];
}
2291
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// Limits visible in the code below:
//   * All scores for a row are kept in `scores[2048]` and indexed by
//     position with no clamp, so base_pos + M must not exceed 2048;
//     longer contexts write past the array.
//   * The reductions read 8 per-simdgroup slots and the loops stride by
//     8 / 256, so the kernel needs exactly 8 simdgroups (256 threads) per
//     threadgroup.
//   * num_heads / num_kv_heads query heads share each KV head (integer
//     division, so num_heads is assumed to be a multiple of num_kv_heads).
// NOTE(review): step 3 reinterprets v_cache / output_batch addresses as
// float4.  That presumably relies on head_dim being a multiple of 4 so the
// offsets stay 16-byte aligned, even though a scalar tail loop exists for
// other head_dim values — confirm.
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // Uniform per threadgroup, so no thread is stranded at a barrier.
    if (token_idx >= M) return;

    // Grouped-query mapping: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores
    threadgroup float scores[2048];

    // Step 1: Q * K^T with simdgroup reduction
    // Each simdgroup takes every 8th position; its 32 lanes split head_dim.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            // Scale by 1/sqrt(head_dim).
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a: row maximum — per thread, then per simdgroup, then thread 0
    // folds the 8 simdgroup results into shared_max[0].
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate relative to the maximum and sum, using the same
    // three-level reduction.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Store the reciprocal; a non-positive total yields 0 instead of
        // dividing by zero.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // 2c: normalize in place so scores[] holds the attention weights.
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
    // This is much better than the scalar version where only 64 of 256 threads are active.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        // Positions are consumed four at a time; seq_len4 is seq_len
        // rounded down to a multiple of 4.
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
        }
        // Leftover 0..3 positions.
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2416
// ── attention_flash_batch ─────────────────────────────────────────────
// Causal prefill attention computed with an online (streaming) softmax.
// Dispatch shape matches attention_batch: M × num_heads threadgroups, one
// per (token, head) pair.  Rather than holding a score for every cached
// position, the kernel walks K/V in tiles of FLASH_K_TILE positions and
// carries a running (m, l, O) triple updated per tile:
//
//     m_new   = max(m_old, tile_max)
//     alpha   = exp(m_old - m_new)
//     l_new   = alpha * l_old + sum(exp(S - m_new))
//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
//     O_final = O / l
//
// Threadgroup memory is therefore O(head_dim + FLASH_K_TILE), independent
// of sequence length, which lifts the 2048-position ceiling imposed by the
// fixed `scores[2048]` array in attention_batch.
//
// Requirements: head_dim ≤ FLASH_MAX_HEAD_DIM (256), which holds for
// Llama/Qwen/Mistral/Phi-3.  The reductions read 8 per-simdgroup slots and
// loops stride by 8 / 256, i.e. 8 simdgroups (256 threads) per threadgroup.
constant constexpr uint FLASH_K_TILE = 32;
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;

kernel void attention_flash_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],
    constant uint& base_pos          [[buffer(5)]],
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Which (token, head) pair this threadgroup serves.  The early exit is
    // uniform across the threadgroup, so it cannot strand a barrier.
    const uint tok = tgid / num_heads;
    if (tok >= M) return;
    const uint head = tgid % num_heads;

    // Grouped-query mapping and the causal window [0, base_pos + tok].
    const uint kv_head = head / (num_heads / num_kv_heads);
    const uint ctx_len = base_pos + tok + 1;

    // Element offsets into the strided Q buffer and the dense output.
    const uint q_base = tok * q_stride + head * head_dim;
    const uint o_base = tok * num_heads * head_dim + head * head_dim;

    // K and V share one cache layout: row stride and per-head column base.
    const uint kv_row_stride = num_kv_heads * head_dim;
    const uint kv_col_base = kv_head * head_dim;
    const float qk_scale = fast::rsqrt(float(head_dim));

    // Threadgroup-resident state.
    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];   // query vector, loaded once
    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];   // running un-normalized output
    threadgroup float scores_sh[FLASH_K_TILE];    // scores / weights of the current tile
    threadgroup float stats[2];                   // [0] running max m, [1] running sum l
    threadgroup float sg_scratch[8];              // per-simdgroup reduction slots + broadcast

    // --- Initialize: m = -inf, l = 0, O = 0, and cache Q ---
    if (tid == 0) {
        stats[0] = -INFINITY;
        stats[1] = 0.0f;
    }
    for (uint d = tid; d < head_dim; d += 256) {
        o_sh[d] = 0.0f;
        q_sh[d] = q_batch[q_base + d];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // --- Stream the causal window tile by tile ---
    uint tile_start = 0;
    while (tile_start < ctx_len) {
        // The final tile may be short.
        const uint left = ctx_len - tile_start;
        const uint tile_n = (left < FLASH_K_TILE) ? left : FLASH_K_TILE;

        // [1] Scaled dot products q · k for the tile.  Simdgroups take
        //     positions round-robin; lanes split head_dim and reduce.
        for (uint ti = simd_id; ti < tile_n; ti += 8) {
            const uint k_idx = (tile_start + ti) * kv_row_stride + kv_col_base;
            float partial = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                partial += q_sh[d] * k_cache[k_idx + d];
            }
            partial = simd_sum(partial);
            if (simd_lane == 0) {
                scores_sh[ti] = partial * qk_scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Maximum score in the tile: thread → simdgroup → slot.
        float thread_max = -INFINITY;
        for (uint s = tid; s < tile_n; s += 256) {
            thread_max = max(thread_max, scores_sh[s]);
        }
        thread_max = simd_max(thread_max);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = thread_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [3] Thread 0 folds the slots, merges with the running max,
        //     rescales the running sum, and publishes (alpha, m_new)
        //     through sg_scratch[0..1].
        if (tid == 0) {
            float tile_max = sg_scratch[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);

            const float m_prev = stats[0];
            const float m_next = max(m_prev, tile_max);

            // On the first tile m_prev is -inf; use 0 rather than
            // evaluating exp(-inf - m_next).
            float rescale = 0.0f;
            if (m_prev != -INFINITY) {
                rescale = fast::exp(m_prev - m_next);
            }

            stats[0] = m_next;
            stats[1] *= rescale;
            sg_scratch[0] = rescale;
            sg_scratch[1] = m_next;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        const float alpha = sg_scratch[0];
        const float m_new = sg_scratch[1];

        // [4] Bring the running output onto the new max, and turn the
        //     tile's scores into un-normalized weights exp(S - m_new).
        for (uint d = tid; d < head_dim; d += 256) {
            o_sh[d] *= alpha;
        }
        for (uint s = tid; s < tile_n; s += 256) {
            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Add the tile's weight total to the running sum l.
        float thread_sum = 0.0f;
        for (uint s = tid; s < tile_n; s += 256) {
            thread_sum += scores_sh[s];
        }
        thread_sum = simd_sum(thread_sum);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = thread_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
            stats[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] O += P · V for the tile; one thread per output dimension.
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = 0.0f;
            for (uint s = 0; s < tile_n; s++) {
                acc += scores_sh[s] * v_cache[(tile_start + s) * kv_row_stride + kv_col_base + d];
            }
            o_sh[d] += acc;
        }
        // scores_sh and sg_scratch are rewritten by the next tile.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        tile_start += FLASH_K_TILE;
    }

    // --- Final normalization: O / l (0 if l is not positive) ---
    const float l_total = stats[1];
    const float inv_l = (l_total > 0.0f) ? (1.0f / l_total) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        output_batch[o_base + d] = o_sh[d] * inv_l;
    }
}
2581
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Rotary position embedding applied in place to Q and K with one dispatch
// (one kernel launch + memory barrier fewer per layer than separate Q and
// K passes).  Q and K both live in each token's row of qkv_data:
//   Q: offset 0,                      num_q_heads heads
//   K: offset num_q_heads * head_dim, num_kv_heads heads
// One thread rotates one adjacent pair (x[2i], x[2i+1]) of one head by the
// angle pos * theta^(-2i / head_dim), where pos = base_pos + token.
kernel void rope_qk_batch(
    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
    constant uint& M                 [[buffer(1)]],   // num tokens
    constant uint& base_pos          [[buffer(2)]],   // starting position
    constant uint& num_q_heads       [[buffer(3)]],
    constant uint& num_kv_heads      [[buffer(4)]],
    constant uint& head_dim          [[buffer(5)]],
    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
    constant float& theta            [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    // Pairs per head, per Q section, and per token (Q followed by K).
    const uint half_dim = head_dim / 2;
    const uint q_pairs = num_q_heads * half_dim;
    const uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;

    // Flat thread index -> (token, pair within the token).
    const uint token = id / total_pairs;
    if (token >= M) return;
    const uint pair = id % total_pairs;

    // Pairs [0, q_pairs) belong to Q; the rest belong to K, which starts
    // num_q_heads * head_dim floats into the row.
    const bool in_k = pair >= q_pairs;
    const uint section_pair = in_k ? (pair - q_pairs) : pair;
    const uint section_start = in_k ? (num_q_heads * head_dim) : 0u;

    const uint h = section_pair / half_dim;   // head within its section
    const uint i = section_pair % half_dim;   // pair index within the head
    const uint offset = token * qkv_stride + section_start + h * head_dim + i * 2;

    // Rotation angle for this pair at this token's absolute position.
    const float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    const float angle = float(base_pos + token) * freq;
    const float cos_val = cos(angle);
    const float sin_val = sin(angle);

    // 2-D rotation of the pair, written back in place.
    const float even = qkv_data[offset];
    const float odd = qkv_data[offset + 1];
    qkv_data[offset]     = even * cos_val - odd * sin_val;
    qkv_data[offset + 1] = even * sin_val + odd * cos_val;
}
2632
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Single-dispatch K and V cache fill: scatters both tensors from the
// strided batch QKV buffer into their own caches, replacing two
// copy_kv_batch launches (and the barrier between them) per layer.
// The grid covers 2 * M * kv_dim elements: the first half of the thread
// indices copies K, the second half copies V.
kernel void copy_kv_both_batch(
    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
    constant uint& M           [[buffer(3)]],  // num tokens in batch
    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    // Elements per tensor; threads beyond both tensors are grid padding.
    const uint per_tensor = M * kv_dim;
    if (id >= per_tensor * 2) return;

    // Flat index -> (tensor, token, column within the KV vector).
    const bool want_v = (id / per_tensor) != 0;
    const uint elem = id % per_tensor;
    const uint token = elem / kv_dim;
    const uint d = elem % kv_dim;

    // Both tensors share the source row and the destination slot; only the
    // column offset in the source row and the target cache differ.
    const uint row_offset = want_v ? v_offset : k_offset;
    const float value = src[token * src_stride + row_offset + d];
    const uint dst_off = (base_pos + token) * kv_dim + d;

    device float* cache = want_v ? v_dst : k_dst;
    cache[dst_off] = value;
}
2667"#
2668    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2669}
2670
2671// ---------------------------------------------------------------------------
2672// model.rs generation
2673// ---------------------------------------------------------------------------
2674
2675fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2676    let mut code = String::with_capacity(48 * 1024);
2677    emit_model_header(&mut code, config)?;
2678    emit_metal_model_struct(&mut code, config)?;
2679    emit_layer_buffers_struct(&mut code, config)?;
2680    emit_metal_model_impl(&mut code, config)?;
2681    emit_helper_functions(&mut code)?;
2682    Ok(code)
2683}
2684
2685fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2686    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2687    writeln!(
2688        code,
2689        "//! Model: {} ({} layers, hidden={})",
2690        config.architecture, config.num_layers, config.hidden_size
2691    )?;
2692    writeln!(code, "//!")?;
2693    writeln!(
2694        code,
2695        "//! Uses native Metal compute pipelines via the metal crate."
2696    )?;
2697    writeln!(code)?;
2698    writeln!(code, "#![allow(dead_code)]")?;
2699    writeln!(code)?;
2700    writeln!(code, "use metal::*;")?;
2701    writeln!(code, "#[allow(unused_imports)]")?;
2702    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2703    writeln!(code, "use std::mem;")?;
2704    writeln!(code)?;
2705
2706    // Model constants
2707    writeln!(
2708        code,
2709        "// ── Model constants ──────────────────────────────────"
2710    )?;
2711    writeln!(
2712        code,
2713        "pub const HIDDEN_SIZE: usize = {};",
2714        config.hidden_size
2715    )?;
2716    writeln!(
2717        code,
2718        "pub const INTERMEDIATE_SIZE: usize = {};",
2719        config.intermediate_size
2720    )?;
2721    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2722    writeln!(
2723        code,
2724        "pub const NUM_HEADS: usize = {};",
2725        config.num_attention_heads
2726    )?;
2727    writeln!(
2728        code,
2729        "pub const NUM_KV_HEADS: usize = {};",
2730        config.num_kv_heads
2731    )?;
2732    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2733    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2734    let effective_seq_len = config.max_seq_len.min(4096);
2735    writeln!(
2736        code,
2737        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
2738        effective_seq_len, config.max_seq_len
2739    )?;
2740    writeln!(
2741        code,
2742        "pub const RMS_NORM_EPS: f32 = {:e};",
2743        config.rms_norm_eps
2744    )?;
2745    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
2746    writeln!(
2747        code,
2748        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
2749    )?;
2750    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
2751    writeln!(code)?;
2752
2753    Ok(())
2754}
2755
2756fn emit_metal_model_struct(
2757    code: &mut String,
2758    config: &ModelConfig,
2759) -> Result<(), MetalCodegenError> {
2760    writeln!(
2761        code,
2762        "// ── MetalModel ──────────────────────────────────────────"
2763    )?;
2764    writeln!(code)?;
2765    writeln!(
2766        code,
2767        "/// Metal-accelerated transformer model for Apple Silicon."
2768    )?;
2769    writeln!(code, "///")?;
2770    writeln!(
2771        code,
2772        "/// Uses unified memory for zero-copy weight access and native Metal"
2773    )?;
2774    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
2775    writeln!(code, "pub struct MetalModel {{")?;
2776    writeln!(code, "    device: Device,")?;
2777    writeln!(code, "    queue: CommandQueue,")?;
2778    writeln!(code)?;
2779    writeln!(code, "    // ── Compute pipelines ──")?;
2780    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
2781    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
2782    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
2783    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
2784    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
2785    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
2786    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
2787    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
2788    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
2789    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
2790    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
2791    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
2792    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
2793    writeln!(code)?;
2794    writeln!(code, "    // ── Batched prefill pipelines ──")?;
2795    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
2796    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
2797    writeln!(
2798        code,
2799        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
2800    )?;
2801    writeln!(
2802        code,
2803        "    matmul_q8_mma_pipeline: ComputePipelineState,"
2804    )?;
2805    writeln!(
2806        code,
2807        "    matmul_q8_mma32_pipeline: ComputePipelineState,"
2808    )?;
2809    writeln!(
2810        code,
2811        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
2812    )?;
2813    writeln!(
2814        code,
2815        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
2816    )?;
2817    writeln!(
2818        code,
2819        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
2820    )?;
2821    if config.qkv_bias {
2822        writeln!(
2823            code,
2824            "    add_bias_batch_pipeline: ComputePipelineState,"
2825        )?;
2826    }
2827    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
2828    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
2829    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
2830    writeln!(
2831        code,
2832        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
2833    )?;
2834    writeln!(
2835        code,
2836        "    add_inplace_batch_pipeline: ComputePipelineState,"
2837    )?;
2838    writeln!(
2839        code,
2840        "    copy_embedding_batch_pipeline: ComputePipelineState,"
2841    )?;
2842    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
2843    writeln!(
2844        code,
2845        "    attention_flash_batch_pipeline: ComputePipelineState,"
2846    )?;
2847    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
2848    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
2849    writeln!(
2850        code,
2851        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
2852    )?;
2853    writeln!(code)?;
2854    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
2855    writeln!(
2856        code,
2857        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
2858    )?;
2859    writeln!(code, "    embed_buf: Buffer,")?;
2860    writeln!(code)?;
2861    writeln!(code, "    /// Per-layer weight buffers")?;
2862    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
2863    writeln!(code)?;
2864    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
2865    writeln!(code, "    norm_buf: Buffer,")?;
2866    writeln!(code)?;
2867    writeln!(
2868        code,
2869        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
2870    )?;
2871    writeln!(code, "    lm_head_buf: Buffer,")?;
2872    writeln!(code)?;
2873    writeln!(
2874        code,
2875        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
2876    )?;
2877    writeln!(code, "    hidden_buf: Buffer,")?;
2878    writeln!(code, "    residual_buf: Buffer,")?;
2879    writeln!(code, "    normed_buf: Buffer,")?;
2880    writeln!(
2881        code,
2882        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
2883    )?;
2884    writeln!(code, "    qkv_buf: Buffer,")?;
2885    writeln!(code, "    attn_out_buf: Buffer,")?;
2886    writeln!(code, "    attn_proj_buf: Buffer,")?;
2887    writeln!(
2888        code,
2889        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
2890    )?;
2891    writeln!(code, "    gate_up_buf: Buffer,")?;
2892    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
2893    writeln!(code, "    ffn_out_buf: Buffer,")?;
2894    writeln!(code, "    add_tmp_buf: Buffer,")?;
2895    writeln!(code, "    logits_buf: Buffer,")?;
2896    writeln!(code)?;
2897    writeln!(code, "    // ── Batched prefill working buffers ──")?;
2898    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
2899    writeln!(code, "    batch_hidden_buf: Buffer,")?;
2900    writeln!(
2901        code,
2902        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
2903    )?;
2904    writeln!(code, "    batch_residual_buf: Buffer,")?;
2905    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2906    writeln!(code, "    batch_qkv_buf: Buffer,")?;
2907    writeln!(
2908        code,
2909        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2910    )?;
2911    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
2912    writeln!(
2913        code,
2914        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2915    )?;
2916    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
2917    writeln!(
2918        code,
2919        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2920    )?;
2921    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
2922    writeln!(
2923        code,
2924        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2925    )?;
2926    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
2927    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2928    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
2929    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
2930    writeln!(code, "    batch_tokens_buf: Buffer,")?;
2931    writeln!(code, "    /// Positions buffer for batch RoPE")?;
2932    writeln!(code, "    batch_positions_buf: Buffer,")?;
2933    writeln!(code)?;
2934    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
2935    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
2936    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
2937    writeln!(code)?;
2938    writeln!(code, "    // ── Inference state ──")?;
2939    writeln!(code, "    pos: usize,")?;
2940    writeln!(code)?;
2941    writeln!(
2942        code,
2943        "    /// Previous command buffer for double-buffered prefill."
2944    )?;
2945    writeln!(
2946        code,
2947        "    /// While the GPU executes token N, the CPU can encode token N+1."
2948    )?;
2949    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
2950    writeln!(code, "}}")?;
2951    writeln!(code)?;
2952
2953    Ok(())
2954}
2955
2956fn emit_layer_buffers_struct(
2957    code: &mut String,
2958    config: &ModelConfig,
2959) -> Result<(), MetalCodegenError> {
2960    writeln!(
2961        code,
2962        "/// Per-layer weight buffers for attention and FFN projections."
2963    )?;
2964    writeln!(code, "struct LayerBuffers {{")?;
2965    writeln!(code, "    attn_norm: Buffer,")?;
2966    writeln!(
2967        code,
2968        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2969    )?;
2970    writeln!(code, "    qkv_weight: Buffer,")?;
2971    if config.qkv_bias {
2972        writeln!(
2973            code,
2974            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
2975        )?;
2976        writeln!(code, "    qkv_bias: Buffer,")?;
2977    }
2978    writeln!(code, "    o_weight: Buffer,")?;
2979    writeln!(code, "    ffn_norm: Buffer,")?;
2980    writeln!(
2981        code,
2982        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2983    )?;
2984    writeln!(code, "    gate_up_weight: Buffer,")?;
2985    writeln!(code, "    down_weight: Buffer,")?;
2986    writeln!(code, "}}")?;
2987    writeln!(code)?;
2988
2989    Ok(())
2990}
2991
2992fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2993    let hidden = config.hidden_size;
2994    let intermediate = config.intermediate_size;
2995    let _num_layers = config.num_layers;
2996    let num_heads = config.num_attention_heads;
2997    let num_kv_heads = config.num_kv_heads;
2998    let head_dim = config.head_dim;
2999    let vocab = config.vocab_size;
3000    let effective_seq_len = config.max_seq_len.min(4096);
3001    let is_q8 = config.dtype == DType::Q8_0;
3002    let is_q4 = config.dtype == DType::Q4_0;
3003    let kv_dim = num_kv_heads * head_dim;
3004
3005    writeln!(code, "impl MetalModel {{")?;
3006
3007    // ── new() ──
3008    writeln!(
3009        code,
3010        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3011    )?;
3012    writeln!(code, "    ///")?;
3013    writeln!(
3014        code,
3015        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3016    )?;
3017    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3018    writeln!(
3019        code,
3020        "        let device = Device::system_default().expect(\"no Metal device found\");"
3021    )?;
3022    writeln!(code, "        let queue = device.new_command_queue();")?;
3023    writeln!(code)?;
3024
3025    // Compile shaders
3026    writeln!(
3027        code,
3028        "        // Compile Metal shaders from embedded source"
3029    )?;
3030    writeln!(
3031        code,
3032        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3033    )?;
3034    writeln!(
3035        code,
3036        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3037    )?;
3038    writeln!(
3039        code,
3040        "            .expect(\"failed to compile Metal shaders\");"
3041    )?;
3042    writeln!(code)?;
3043
3044    // Create compute pipelines
3045    writeln!(code, "        // Create compute pipelines")?;
3046    for (var, fn_name) in [
3047        ("matmul_pipeline", "matmul_vec"),
3048        ("matmul_q8_pipeline", "matmul_vec_q8"),
3049        ("matmul_q4_pipeline", "matmul_vec_q4"),
3050        ("rms_norm_pipeline", "rms_norm"),
3051        ("rope_pipeline", "rope"),
3052        ("softmax_pipeline", "softmax"),
3053        ("silu_mul_pipeline", "silu_mul"),
3054        ("silu_mul_fused_pipeline", "silu_mul_fused"),
3055        ("add_pipeline", "elementwise_add"),
3056        ("attention_pipeline", "attention"),
3057        ("add_inplace_pipeline", "add_inplace"),
3058        ("copy_pipeline", "copy_buffer"),
3059        ("copy_offset_pipeline", "copy_offset"),
3060        ("matmul_batch_pipeline", "matmul_vec_batch"),
3061        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3062        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3063        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3064        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3065        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3066        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3067        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3068        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3069        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3070        ("rope_batch_pipeline", "rope_batch"),
3071        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3072        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3073        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3074        ("attention_batch_pipeline", "attention_batch"),
3075        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3076        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3077        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3078        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3079    ] {
3080        writeln!(
3081            code,
3082            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3083        )?;
3084    }
3085    if config.qkv_bias {
3086        writeln!(
3087            code,
3088            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3089        )?;
3090    }
3091    writeln!(code)?;
3092
3093    // Weight loading
3094    writeln!(
3095        code,
3096        "        // Load weights into Metal shared-memory buffers"
3097    )?;
3098    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3099    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3100    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3101    writeln!(code)?;
3102    writeln!(
3103        code,
3104        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3105    )?;
3106    writeln!(code)?;
3107    writeln!(
3108        code,
3109        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3110    )?;
3111    writeln!(
3112        code,
3113        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3114    )?;
3115    writeln!(code, "            let byte_len = n * f32_size;")?;
3116    writeln!(code, "            let cur = cursor.get();")?;
3117    writeln!(
3118        code,
3119        "            let data = &weights[cur..cur + byte_len];"
3120    )?;
3121    writeln!(code, "            cursor.set(cur + byte_len);")?;
3122    writeln!(code, "            device.new_buffer_with_data(")?;
3123    writeln!(code, "                data.as_ptr() as *const _,")?;
3124    writeln!(code, "                byte_len as u64,")?;
3125    writeln!(
3126        code,
3127        "                MTLResourceOptions::StorageModeShared,"
3128    )?;
3129    writeln!(code, "            )")?;
3130    writeln!(code, "        }};")?;
3131    writeln!(code)?;
3132
3133    if is_q8 {
3134        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3135        // We load them directly into Metal buffers without dequantizing,
3136        // and use the matmul_vec_q8 shader that operates on quantized data.
3137        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
3138        writeln!(
3139            code,
3140            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3141        )?;
3142        writeln!(
3143            code,
3144            "        // as raw bytes into a Metal buffer (no dequantization)."
3145        )?;
3146        writeln!(
3147            code,
3148            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3149        )?;
3150        writeln!(
3151            code,
3152            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3153        )?;
3154        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3155        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3156        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3157        writeln!(code, "            let cur = cursor.get();")?;
3158        writeln!(
3159            code,
3160            "            let data = &weights[cur..cur + total_raw];"
3161        )?;
3162        writeln!(code, "            cursor.set(cur + total_raw);")?;
3163        writeln!(code, "            device.new_buffer_with_data(")?;
3164        writeln!(code, "                data.as_ptr() as *const _,")?;
3165        writeln!(code, "                total_raw as u64,")?;
3166        writeln!(
3167            code,
3168            "                MTLResourceOptions::StorageModeShared,"
3169        )?;
3170        writeln!(code, "            )")?;
3171        writeln!(code, "        }};")?;
3172        writeln!(code)?;
3173        writeln!(
3174            code,
3175            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3176        )?;
3177        writeln!(
3178            code,
3179            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3180        )?;
3181        writeln!(
3182            code,
3183            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3184        )?;
3185        writeln!(
3186            code,
3187            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3188        )?;
3189        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3190        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3191        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3192        writeln!(code, "            let cur = cursor.get();")?;
3193        writeln!(
3194            code,
3195            "            let data = &weights[cur..cur + total_raw];"
3196        )?;
3197        writeln!(code, "            cursor.set(cur + total_raw);")?;
3198        writeln!(code, "            device.new_buffer_with_data(")?;
3199        writeln!(code, "                data.as_ptr() as *const _,")?;
3200        writeln!(code, "                total_raw as u64,")?;
3201        writeln!(
3202            code,
3203            "                MTLResourceOptions::StorageModeShared,"
3204        )?;
3205        writeln!(code, "            )")?;
3206        writeln!(code, "        }};")?;
3207        writeln!(code)?;
3208    }
3209
3210    if is_q4 {
3211        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3212        // We load them directly into Metal buffers without dequantizing,
3213        // and use the matmul_vec_q4 shader that operates on quantized data.
3214        // This quarters GPU memory usage vs f32 dequantization.
3215        writeln!(
3216            code,
3217            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3218        )?;
3219        writeln!(
3220            code,
3221            "        // as raw bytes into a Metal buffer (no dequantization)."
3222        )?;
3223        writeln!(
3224            code,
3225            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3226        )?;
3227        writeln!(
3228            code,
3229            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3230        )?;
3231        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3232        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3233        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3234        writeln!(code, "            let cur = cursor.get();")?;
3235        writeln!(
3236            code,
3237            "            let data = &weights[cur..cur + total_raw];"
3238        )?;
3239        writeln!(code, "            cursor.set(cur + total_raw);")?;
3240        writeln!(code, "            device.new_buffer_with_data(")?;
3241        writeln!(code, "                data.as_ptr() as *const _,")?;
3242        writeln!(code, "                total_raw as u64,")?;
3243        writeln!(
3244            code,
3245            "                MTLResourceOptions::StorageModeShared,"
3246        )?;
3247        writeln!(code, "            )")?;
3248        writeln!(code, "        }};")?;
3249        writeln!(code)?;
3250        writeln!(
3251            code,
3252            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3253        )?;
3254        writeln!(
3255            code,
3256            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3257        )?;
3258        writeln!(
3259            code,
3260            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3261        )?;
3262        writeln!(
3263            code,
3264            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3265        )?;
3266        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3267        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3268        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3269        writeln!(code, "            let cur = cursor.get();")?;
3270        writeln!(
3271            code,
3272            "            let data = &weights[cur..cur + total_raw];"
3273        )?;
3274        writeln!(code, "            cursor.set(cur + total_raw);")?;
3275        writeln!(code, "            device.new_buffer_with_data(")?;
3276        writeln!(code, "                data.as_ptr() as *const _,")?;
3277        writeln!(code, "                total_raw as u64,")?;
3278        writeln!(
3279            code,
3280            "                MTLResourceOptions::StorageModeShared,"
3281        )?;
3282        writeln!(code, "            )")?;
3283        writeln!(code, "        }};")?;
3284        writeln!(code)?;
3285    }
3286
3287    writeln!(
3288        code,
3289        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3290    )?;
3291    writeln!(code)?;
3292
3293    // Per-layer weights
3294    writeln!(
3295        code,
3296        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3297    )?;
3298    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3299
3300    // attn_norm is always f32
3301    writeln!(
3302        code,
3303        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3304    )?;
3305
3306    let qkv_rows = hidden + 2 * kv_dim;
3307    if is_q8 {
3308        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3309        writeln!(
3310            code,
3311            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3312        )?;
3313        if config.qkv_bias {
3314            writeln!(
3315                code,
3316                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3317            )?;
3318            writeln!(
3319                code,
3320                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3321            )?;
3322        }
3323        writeln!(
3324            code,
3325            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3326        )?;
3327    } else if is_q4 {
3328        writeln!(
3329            code,
3330            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3331        )?;
3332        if config.qkv_bias {
3333            writeln!(
3334                code,
3335                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3336            )?;
3337        }
3338        writeln!(
3339            code,
3340            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3341        )?;
3342    } else {
3343        writeln!(
3344            code,
3345            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3346        )?;
3347        if config.qkv_bias {
3348            writeln!(
3349                code,
3350                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3351            )?;
3352        }
3353        writeln!(
3354            code,
3355            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3356        )?;
3357    }
3358
3359    // ffn_norm is always f32
3360    writeln!(
3361        code,
3362        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3363    )?;
3364
3365    let gate_up_rows = 2 * intermediate;
3366    if is_q8 {
3367        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3368        writeln!(
3369            code,
3370            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3371        )?;
3372        writeln!(
3373            code,
3374            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3375        )?;
3376    } else if is_q4 {
3377        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3378        writeln!(
3379            code,
3380            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3381        )?;
3382        writeln!(
3383            code,
3384            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3385        )?;
3386    } else {
3387        // Fused gate+up weight: read both as a single contiguous f32 buffer
3388        writeln!(
3389            code,
3390            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3391        )?;
3392        writeln!(
3393            code,
3394            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3395        )?;
3396    }
3397
3398    writeln!(code, "            layers.push(LayerBuffers {{")?;
3399    writeln!(code, "                attn_norm,")?;
3400    writeln!(code, "                qkv_weight,")?;
3401    if config.qkv_bias {
3402        writeln!(code, "                qkv_bias,")?;
3403    }
3404    writeln!(code, "                o_weight,")?;
3405    writeln!(code, "                ffn_norm,")?;
3406    writeln!(code, "                gate_up_weight,")?;
3407    writeln!(code, "                down_weight,")?;
3408    writeln!(code, "            }});")?;
3409    writeln!(code, "        }}")?;
3410    writeln!(code)?;
3411
3412    // final_norm is always f32
3413    writeln!(
3414        code,
3415        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3416    )?;
3417    writeln!(code)?;
3418
3419    // lm_head
3420    if is_q8 {
3421        writeln!(
3422            code,
3423            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3424        )?;
3425    } else if is_q4 {
3426        writeln!(
3427            code,
3428            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3429        )?;
3430    } else {
3431        writeln!(
3432            code,
3433            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3434        )?;
3435    }
3436    writeln!(code)?;
3437
3438    // Working buffers
3439    let hidden_bytes = hidden * 4;
3440    let _kv_dim_bytes = kv_dim * 4;
3441    let intermediate_bytes = intermediate * 4;
3442    let vocab_bytes = vocab * 4;
3443    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3444
3445    writeln!(
3446        code,
3447        "        // Allocate working buffers (shared memory for zero-copy)"
3448    )?;
3449    writeln!(
3450        code,
3451        "        let opts = MTLResourceOptions::StorageModeShared;"
3452    )?;
3453    writeln!(
3454        code,
3455        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3456    )?;
3457    writeln!(
3458        code,
3459        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3460    )?;
3461    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3462    writeln!(
3463        code,
3464        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3465    )?;
3466    writeln!(
3467        code,
3468        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3469    )?;
3470    writeln!(
3471        code,
3472        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3473    )?;
3474    writeln!(
3475        code,
3476        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3477    )?;
3478    writeln!(
3479        code,
3480        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3481    )?;
3482    let gate_up_buf_bytes = 2 * intermediate * 4;
3483    writeln!(
3484        code,
3485        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3486    )?;
3487    writeln!(
3488        code,
3489        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3490    )?;
3491    writeln!(
3492        code,
3493        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3494    )?;
3495    writeln!(
3496        code,
3497        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3498    )?;
3499    writeln!(
3500        code,
3501        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3502    )?;
3503    writeln!(
3504        code,
3505        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3506    )?;
3507    writeln!(code)?;
3508
3509    // Batch prefill working buffers
3510    let batch_hidden_bytes = hidden * 4; // per-token
3511    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3512    let batch_gate_up_bytes = 2 * intermediate * 4;
3513    let batch_intermediate_bytes = intermediate * 4;
3514    writeln!(
3515        code,
3516        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3517    )?;
3518    writeln!(
3519        code,
3520        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3521    )?;
3522    writeln!(
3523        code,
3524        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3525    )?;
3526    writeln!(
3527        code,
3528        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3529    )?;
3530    writeln!(
3531        code,
3532        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3533    )?;
3534    writeln!(
3535        code,
3536        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3537    )?;
3538    writeln!(
3539        code,
3540        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3541    )?;
3542    writeln!(
3543        code,
3544        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3545    )?;
3546    writeln!(
3547        code,
3548        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3549    )?;
3550    writeln!(
3551        code,
3552        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3553    )?;
3554    writeln!(
3555        code,
3556        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3557    )?;
3558    writeln!(code)?;
3559
3560    // KV cache buffers
3561    writeln!(code, "        // KV cache buffers (per-layer)")?;
3562    writeln!(
3563        code,
3564        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3565    )?;
3566    writeln!(
3567        code,
3568        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3569    )?;
3570    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3571    writeln!(
3572        code,
3573        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3574    )?;
3575    writeln!(
3576        code,
3577        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3578    )?;
3579    writeln!(code, "        }}")?;
3580    writeln!(code)?;
3581
3582    writeln!(code, "        Self {{")?;
3583    writeln!(code, "            device,")?;
3584    writeln!(code, "            queue,")?;
3585    writeln!(code, "            matmul_pipeline,")?;
3586    writeln!(code, "            matmul_q8_pipeline,")?;
3587    writeln!(code, "            matmul_q4_pipeline,")?;
3588    writeln!(code, "            rms_norm_pipeline,")?;
3589    writeln!(code, "            rope_pipeline,")?;
3590    writeln!(code, "            softmax_pipeline,")?;
3591    writeln!(code, "            silu_mul_pipeline,")?;
3592    writeln!(code, "            silu_mul_fused_pipeline,")?;
3593    writeln!(code, "            add_pipeline,")?;
3594    writeln!(code, "            attention_pipeline,")?;
3595    writeln!(code, "            add_inplace_pipeline,")?;
3596    writeln!(code, "            copy_pipeline,")?;
3597    writeln!(code, "            copy_offset_pipeline,")?;
3598    writeln!(code, "            matmul_batch_pipeline,")?;
3599    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3600    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3601    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3602    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3603    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3604    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3605    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3606    if config.qkv_bias {
3607        writeln!(code, "            add_bias_batch_pipeline,")?;
3608    }
3609    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3610    writeln!(code, "            rms_norm_batch_pipeline,")?;
3611    writeln!(code, "            rope_batch_pipeline,")?;
3612    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3613    writeln!(code, "            add_inplace_batch_pipeline,")?;
3614    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3615    writeln!(code, "            attention_batch_pipeline,")?;
3616    writeln!(code, "            attention_flash_batch_pipeline,")?;
3617    writeln!(code, "            copy_kv_batch_pipeline,")?;
3618    writeln!(code, "            rope_qk_batch_pipeline,")?;
3619    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3620    writeln!(code, "            embed_buf,")?;
3621    writeln!(code, "            layers,")?;
3622    writeln!(code, "            norm_buf,")?;
3623    writeln!(code, "            lm_head_buf,")?;
3624    writeln!(code, "            hidden_buf,")?;
3625    writeln!(code, "            residual_buf,")?;
3626    writeln!(code, "            normed_buf,")?;
3627    writeln!(code, "            qkv_buf,")?;
3628    writeln!(code, "            attn_out_buf,")?;
3629    writeln!(code, "            attn_proj_buf,")?;
3630    writeln!(code, "            gate_up_buf,")?;
3631    writeln!(code, "            ffn_hidden_buf,")?;
3632    writeln!(code, "            ffn_out_buf,")?;
3633    writeln!(code, "            add_tmp_buf,")?;
3634    writeln!(code, "            logits_buf,")?;
3635    writeln!(code, "            batch_hidden_buf,")?;
3636    writeln!(code, "            batch_residual_buf,")?;
3637    writeln!(code, "            batch_qkv_buf,")?;
3638    writeln!(code, "            batch_attn_out_buf,")?;
3639    writeln!(code, "            batch_attn_proj_buf,")?;
3640    writeln!(code, "            batch_gate_up_buf,")?;
3641    writeln!(code, "            batch_ffn_hidden_buf,")?;
3642    writeln!(code, "            batch_ffn_out_buf,")?;
3643    writeln!(code, "            batch_tokens_buf,")?;
3644    writeln!(code, "            batch_positions_buf,")?;
3645    writeln!(code, "            k_cache,")?;
3646    writeln!(code, "            v_cache,")?;
3647    writeln!(code, "            pos: 0,")?;
3648    writeln!(code, "            prev_cmd: None,")?;
3649    writeln!(code, "        }}")?;
3650    writeln!(code, "    }}")?;
3651    writeln!(code)?;
3652
3653    // ── forward() ──
3654    writeln!(
3655        code,
3656        "    /// Run the forward pass for a single token at the current position."
3657    )?;
3658    writeln!(code, "    ///")?;
3659    writeln!(
3660        code,
3661        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3662    )?;
3663    writeln!(code, "    ///")?;
3664    writeln!(
3665        code,
3666        "    /// All GPU operations are encoded into a single command buffer and"
3667    )?;
3668    writeln!(
3669        code,
3670        "    /// committed once at the end, avoiding per-operation synchronization."
3671    )?;
3672    writeln!(
3673        code,
3674        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3675    )?;
3676    writeln!(
3677        code,
3678        "        // Wait for any pending prefill command buffer"
3679    )?;
3680    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3681    writeln!(code, "            prev.wait_until_completed();")?;
3682    writeln!(code, "        }}")?;
3683    writeln!(code)?;
3684    writeln!(code, "        let pos = self.pos;")?;
3685    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3686    writeln!(code)?;
3687
3688    // Single compute encoder for the entire forward pass — no blit encoder
3689    // transitions. Copy operations use compute copy kernels instead of blits.
3690    let matmul_fn = if is_q8 {
3691        "dispatch_matmul_q8"
3692    } else if is_q4 {
3693        "dispatch_matmul_q4"
3694    } else {
3695        "dispatch_matmul"
3696    };
3697
3698    writeln!(
3699        code,
3700        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3701    )?;
3702    writeln!(code, "        {{")?;
3703    writeln!(
3704        code,
3705        "            let enc = cmd.new_compute_command_encoder();"
3706    )?;
3707    writeln!(code)?;
3708
3709    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3710    writeln!(
3711        code,
3712        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3713    )?;
3714    writeln!(
3715        code,
3716        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3717    )?;
3718    writeln!(
3719        code,
3720        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3721    )?;
3722    writeln!(
3723        code,
3724        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3725        hidden * 4,
3726    )?;
3727    writeln!(code, "            unsafe {{")?;
3728    writeln!(
3729        code,
3730        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
3731    )?;
3732    writeln!(
3733        code,
3734        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3735    )?;
3736    writeln!(
3737        code,
3738        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
3739    )?;
3740    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
3741    writeln!(
3742        code,
3743        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3744    )?;
3745    writeln!(code, "                    hidden_ptr,")?;
3746    writeln!(code, "                    HIDDEN_SIZE,")?;
3747    writeln!(code, "                );")?;
3748    writeln!(
3749        code,
3750        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3751    )?;
3752    writeln!(code, "            }}")?;
3753    writeln!(code)?;
3754
3755    // 2. Transformer layers
3756    writeln!(code, "            // 2. Transformer layers")?;
3757    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3758    writeln!(code)?;
3759    let q_byte_offset = 0usize;
3760    let k_byte_offset = hidden * 4;
3761    let v_byte_offset = (hidden + kv_dim) * 4;
3762
3763    writeln!(
3764        code,
3765        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
3766    )?;
3767    writeln!(
3768        code,
3769        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3770    )?;
3771    writeln!(
3772        code,
3773        "                // Fused Q+K+V matmul: single dispatch for all three projections"
3774    )?;
3775    writeln!(
3776        code,
3777        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3778    )?;
3779    if config.qkv_bias {
3780        writeln!(
3781            code,
3782            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
3783        )?;
3784        writeln!(
3785            code,
3786            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3787        )?;
3788    }
3789    writeln!(
3790        code,
3791        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
3792    )?;
3793    writeln!(
3794        code,
3795        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3796    )?;
3797    writeln!(
3798        code,
3799        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3800    )?;
3801    writeln!(code)?;
3802    writeln!(
3803        code,
3804        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
3805    )?;
3806    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3807    writeln!(
3808        code,
3809        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3810    )?;
3811    writeln!(
3812        code,
3813        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3814    )?;
3815    writeln!(code)?;
3816    writeln!(
3817        code,
3818        "                // Attention using Q from qkv_buf (offset 0)"
3819    )?;
3820    writeln!(
3821        code,
3822        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3823    )?;
3824    writeln!(
3825        code,
3826        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3827    )?;
3828    writeln!(
3829        code,
3830        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3831    )?;
3832    writeln!(
3833        code,
3834        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
3835    )?;
3836    writeln!(
3837        code,
3838        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3839    )?;
3840    writeln!(
3841        code,
3842        "                // Fused gate+up matmul: single dispatch for both projections"
3843    )?;
3844    writeln!(
3845        code,
3846        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3847    )?;
3848    writeln!(
3849        code,
3850        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3851    )?;
3852    writeln!(
3853        code,
3854        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3855    )?;
3856    writeln!(
3857        code,
3858        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3859    )?;
3860    writeln!(code, "            }}")?;
3861    writeln!(code)?;
3862
3863    // 3. Final RMS norm + logits
3864    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
3865    writeln!(
3866        code,
3867        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3868    )?;
3869    writeln!(
3870        code,
3871        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3872    )?;
3873    writeln!(code)?;
3874    writeln!(code, "            enc.end_encoding();")?;
3875    writeln!(code, "        }}")?;
3876    writeln!(code)?;
3877
3878    // 5. Single commit + wait, then read back logits
3879    writeln!(
3880        code,
3881        "        // 5. Commit all GPU work and wait for completion"
3882    )?;
3883    writeln!(code, "        cmd.commit();")?;
3884    writeln!(code, "        cmd.wait_until_completed();")?;
3885    writeln!(code)?;
3886    writeln!(code, "        // 6. Read back logits from GPU")?;
3887    writeln!(code, "        let logits = unsafe {{")?;
3888    writeln!(
3889        code,
3890        "            let ptr = self.logits_buf.contents() as *const f32;"
3891    )?;
3892    writeln!(
3893        code,
3894        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3895    )?;
3896    writeln!(code, "        }};")?;
3897    writeln!(code)?;
3898    writeln!(code, "        self.pos += 1;")?;
3899    writeln!(code, "        logits")?;
3900    writeln!(code, "    }}")?;
3901    writeln!(code)?;
3902
3903    // ── forward_profile: instrumented forward with per-operation timing ──
3904    writeln!(
3905        code,
3906        "    /// Profiling forward pass that prints per-stage GPU timing."
3907    )?;
3908    writeln!(code, "    ///")?;
3909    writeln!(
3910        code,
3911        "    /// Each stage is committed and waited on separately so that GPU timestamps"
3912    )?;
3913    writeln!(
3914        code,
3915        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
3916    )?;
3917    writeln!(
3918        code,
3919        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
3920    )?;
3921    writeln!(
3922        code,
3923        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
3924    )?;
3925    writeln!(code, "        use std::time::Instant;")?;
3926    writeln!(code)?;
3927    writeln!(
3928        code,
3929        "        // Wait for any pending prefill command buffer"
3930    )?;
3931    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3932    writeln!(code, "            prev.wait_until_completed();")?;
3933    writeln!(code, "        }}")?;
3934    writeln!(code)?;
3935    writeln!(code, "        let pos = self.pos;")?;
3936    writeln!(code)?;
3937
3938    // Stage: embedding (CPU, no GPU)
3939    writeln!(
3940        code,
3941        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
3942    )?;
3943    writeln!(code, "        let t_embed = Instant::now();")?;
3944    writeln!(code, "        unsafe {{")?;
3945    writeln!(
3946        code,
3947        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
3948    )?;
3949    writeln!(
3950        code,
3951        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3952    )?;
3953    writeln!(
3954        code,
3955        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
3956    )?;
3957    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
3958    writeln!(
3959        code,
3960        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3961    )?;
3962    writeln!(code, "                hidden_ptr,")?;
3963    writeln!(code, "                HIDDEN_SIZE,")?;
3964    writeln!(code, "            );")?;
3965    writeln!(
3966        code,
3967        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3968    )?;
3969    writeln!(code, "        }}")?;
3970    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
3971    writeln!(code)?;
3972
3973    // Stage: Transformer layers (all together on GPU)
3974    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
3975    writeln!(code, "        let t_layers = Instant::now();")?;
3976    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3977    writeln!(code, "        {{")?;
3978    writeln!(
3979        code,
3980        "            let enc = cmd.new_compute_command_encoder();"
3981    )?;
3982    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3983    writeln!(
3984        code,
3985        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3986    )?;
3987    writeln!(
3988        code,
3989        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3990    )?;
3991    if config.qkv_bias {
3992        writeln!(
3993            code,
3994            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3995        )?;
3996    }
3997    writeln!(
3998        code,
3999        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4000    )?;
4001    writeln!(
4002        code,
4003        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4004    )?;
4005    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
4006    writeln!(
4007        code,
4008        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4009    )?;
4010    writeln!(
4011        code,
4012        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4013    )?;
4014    writeln!(
4015        code,
4016        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4017    )?;
4018    writeln!(
4019        code,
4020        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4021    )?;
4022    writeln!(
4023        code,
4024        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4025    )?;
4026    writeln!(
4027        code,
4028        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4029    )?;
4030    writeln!(
4031        code,
4032        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4033    )?;
4034    writeln!(
4035        code,
4036        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4037    )?;
4038    writeln!(
4039        code,
4040        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4041    )?;
4042    writeln!(
4043        code,
4044        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4045    )?;
4046    writeln!(code, "            }}")?;
4047    writeln!(code, "            enc.end_encoding();")?;
4048    writeln!(code, "        }}")?;
4049    writeln!(code, "        cmd.commit();")?;
4050    writeln!(code, "        cmd.wait_until_completed();")?;
4051    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4052    writeln!(code)?;
4053
4054    // Stage: Final norm + logits
4055    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4056    writeln!(code, "        let t_logits = Instant::now();")?;
4057    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4058    writeln!(code, "        {{")?;
4059    writeln!(
4060        code,
4061        "            let enc = cmd2.new_compute_command_encoder();"
4062    )?;
4063    writeln!(
4064        code,
4065        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4066    )?;
4067    writeln!(
4068        code,
4069        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4070    )?;
4071    writeln!(code, "            enc.end_encoding();")?;
4072    writeln!(code, "        }}")?;
4073    writeln!(code, "        cmd2.commit();")?;
4074    writeln!(code, "        cmd2.wait_until_completed();")?;
4075    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4076    writeln!(code)?;
4077
4078    // Print profile results
4079    writeln!(
4080        code,
4081        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4082    )?;
4083    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4084    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4085    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4086    writeln!(
4087        code,
4088        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4089    )?;
4090    writeln!(code)?;
4091
4092    // Read back logits
4093    writeln!(code, "        let logits = unsafe {{")?;
4094    writeln!(
4095        code,
4096        "            let ptr = self.logits_buf.contents() as *const f32;"
4097    )?;
4098    writeln!(
4099        code,
4100        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4101    )?;
4102    writeln!(code, "        }};")?;
4103    writeln!(code)?;
4104    writeln!(code, "        self.pos += 1;")?;
4105    writeln!(code, "        logits")?;
4106    writeln!(code, "    }}")?;
4107    writeln!(code)?;
4108
4109    // ── forward_prefill: single-token async forward (backward compat) ──
4110    writeln!(
4111        code,
4112        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4113    )?;
4114    writeln!(code, "    ///")?;
4115    writeln!(
4116        code,
4117        "    /// Commits the command buffer without waiting, enabling double-buffered"
4118    )?;
4119    writeln!(
4120        code,
4121        "    /// execution: GPU processes token N while CPU encodes token N+1."
4122    )?;
4123    writeln!(
4124        code,
4125        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4126    )?;
4127    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4128    writeln!(code, "    }}")?;
4129    writeln!(code)?;
4130
4131    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4132    // Batched matmuls for QKV/O/FFN projections; causal attention is also batched into a single dispatch.
4133    let batch_matmul_fn = if is_q8 {
4134        "dispatch_matmul_q8_batch"
4135    } else if is_q4 {
4136        "dispatch_matmul_q4_batch"
4137    } else {
4138        "dispatch_matmul_batch"
4139    };
4140
4141    writeln!(
4142        code,
4143        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4144    )?;
4145    writeln!(code, "    ///")?;
4146    writeln!(
4147        code,
4148        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4149    )?;
4150    writeln!(
4151        code,
4152        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4153    )?;
4154    writeln!(
4155        code,
4156        "    /// This provides significant speedup during prompt prefill."
4157    )?;
4158    writeln!(
4159        code,
4160        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4161    )?;
4162    writeln!(code, "        let m = tokens.len().min(MAX_BATCH_SIZE);")?;
4163    writeln!(code, "        if m == 0 {{ return; }}")?;
4164    writeln!(code, "        let start_pos = self.pos;")?;
4165    writeln!(code)?;
4166    writeln!(code, "        // Wait for any pending command buffer")?;
4167    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4168    writeln!(code, "            prev.wait_until_completed();")?;
4169    writeln!(code, "        }}")?;
4170    writeln!(code)?;
4171
4172    // Upload token IDs and positions to GPU
4173    writeln!(
4174        code,
4175        "        // Upload token IDs and positions to GPU buffers"
4176    )?;
4177    writeln!(code, "        unsafe {{")?;
4178    writeln!(
4179        code,
4180        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4181    )?;
4182    writeln!(
4183        code,
4184        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4185    )?;
4186    writeln!(code, "            for i in 0..m {{")?;
4187    writeln!(code, "                *tok_ptr.add(i) = tokens[i];")?;
4188    writeln!(
4189        code,
4190        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4191    )?;
4192    writeln!(code, "            }}")?;
4193    writeln!(code, "        }}")?;
4194    writeln!(code)?;
4195
4196    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4197    writeln!(code, "        {{")?;
4198    writeln!(
4199        code,
4200        "            let enc = cmd.new_compute_command_encoder();"
4201    )?;
4202    writeln!(code)?;
4203
4204    // 1. Batch embedding lookup
4205    writeln!(
4206        code,
4207        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4208    )?;
4209    writeln!(
4210        code,
4211        "            self.dispatch_copy_embedding_batch(&enc, m);"
4212    )?;
4213    // Copy batch_hidden -> batch_residual
4214    writeln!(
4215        code,
4216        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4217    )?;
4218    writeln!(code)?;
4219
4220    // 2. Transformer layers
4221    writeln!(code, "            // 2. Transformer layers")?;
4222    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4223    writeln!(code)?;
4224
4225    // Batch RMS norm: residual -> hidden (batched)
4226    writeln!(
4227        code,
4228        "                // Batch RMS norm: batch_residual -> batch_hidden"
4229    )?;
4230    writeln!(
4231        code,
4232        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4233    )?;
4234
4235    // Batch QKV matmul
4236    writeln!(
4237        code,
4238        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4239    )?;
4240    writeln!(
4241        code,
4242        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4243    )?;
4244    if config.qkv_bias {
4245        writeln!(
4246            code,
4247            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4248        )?;
4249        writeln!(
4250            code,
4251            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4252        )?;
4253    }
4254    writeln!(code)?;
4255
4256    // Fused RoPE on Q+K portions in a single dispatch
4257    let k_float_offset = hidden;
4258    writeln!(
4259        code,
4260        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4261    )?;
4262    writeln!(
4263        code,
4264        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4265    )?;
4266    writeln!(code)?;
4267
4268    // Fused KV cache update: copy both K and V in a single dispatch
4269    let v_float_offset = hidden + kv_dim;
4270    writeln!(
4271        code,
4272        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4273    )?;
4274    writeln!(
4275        code,
4276        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4277    )?;
4278    writeln!(code)?;
4279
4280    // Batched causal attention: ONE dispatch for all M tokens
4281    writeln!(
4282        code,
4283        "                // Batched causal attention: one dispatch for all M tokens"
4284    )?;
4285    writeln!(
4286        code,
4287        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4288    )?;
4289    writeln!(code)?;
4290
4291    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4292    writeln!(code, "                // Batched O projection")?;
4293    writeln!(
4294        code,
4295        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4296    )?;
4297    writeln!(code)?;
4298
4299    // Batch add: residual += attn_proj for all tokens
4300    writeln!(
4301        code,
4302        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4303    )?;
4304    writeln!(code)?;
4305
4306    // Batch FFN
4307    writeln!(
4308        code,
4309        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4310    )?;
4311    writeln!(
4312        code,
4313        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4314    )?;
4315    writeln!(
4316        code,
4317        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4318    )?;
4319    writeln!(
4320        code,
4321        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4322    )?;
4323    writeln!(
4324        code,
4325        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4326    )?;
4327    writeln!(
4328        code,
4329        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4330    )?;
4331    writeln!(code, "            }}")?;
4332    writeln!(code)?;
4333
4334    // Copy last token's residual to single-token residual_buf for next forward() call
4335    writeln!(
4336        code,
4337        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4338    )?;
4339    writeln!(
4340        code,
4341        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4342    )?;
4343    writeln!(code)?;
4344    writeln!(code, "            enc.end_encoding();")?;
4345    writeln!(code, "        }}")?;
4346    writeln!(code)?;
4347
4348    writeln!(code, "        cmd.commit();")?;
4349    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4350    writeln!(code, "        self.pos += m;")?;
4351    writeln!(code, "    }}")?;
4352    writeln!(code)?;
4353
4354    // ── reset() — rewind KV cache position for new inference requests ──
4355    writeln!(
4356        code,
4357        "    /// Reset the model state for a new inference request."
4358    )?;
4359    writeln!(code, "    pub fn reset(&mut self) {{")?;
4360    writeln!(code, "        self.pos = 0;")?;
4361    writeln!(code, "        self.prev_cmd = None;")?;
4362    writeln!(code, "    }}")?;
4363    writeln!(code)?;
4364
4365    // ── Private dispatch helpers (all take a shared compute encoder) ──
4366    writeln!(
4367        code,
4368        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4369    )?;
4370    writeln!(
4371        code,
4372        "    // These methods set pipeline state + buffers + dispatch on an existing"
4373    )?;
4374    writeln!(
4375        code,
4376        "    // encoder, avoiding per-operation encoder creation overhead."
4377    )?;
4378    writeln!(code)?;
4379
4380    // dispatch_rms_norm
4381    writeln!(
4382        code,
4383        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4384    )?;
4385    writeln!(
4386        code,
4387        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4388    )?;
4389    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4390    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4391    writeln!(
4392        code,
4393        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4394    )?;
4395    writeln!(
4396        code,
4397        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4398    )?;
4399    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4400    writeln!(
4401        code,
4402        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4403    )?;
4404    writeln!(
4405        code,
4406        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4407    )?;
4408    writeln!(
4409        code,
4410        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4411    )?;
4412    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4413    writeln!(
4414        code,
4415        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4416    )?;
4417    writeln!(
4418        code,
4419        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4420    )?;
4421    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4422    writeln!(code, "    }}")?;
4423    writeln!(code)?;
4424
4425    // dispatch_matmul
4426    writeln!(
4427        code,
4428        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4429    )?;
4430    writeln!(
4431        code,
4432        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4433    )?;
4434    writeln!(code, "        let r: u32 = rows as u32;")?;
4435    writeln!(code, "        let c: u32 = cols as u32;")?;
4436    writeln!(
4437        code,
4438        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4439    )?;
4440    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4441    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4442    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4443    writeln!(
4444        code,
4445        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4446    )?;
4447    writeln!(
4448        code,
4449        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4450    )?;
4451    writeln!(
4452        code,
4453        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4454    )?;
4455    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4456    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4457    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4458    writeln!(
4459        code,
4460        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4461    )?;
4462    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4463    writeln!(code, "    }}")?;
4464    writeln!(code)?;
4465
4466    // dispatch_matmul_q8
4467    writeln!(
4468        code,
4469        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4470    )?;
4471    writeln!(
4472        code,
4473        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4474    )?;
4475    writeln!(
4476        code,
4477        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4478    )?;
4479    writeln!(code, "        let r: u32 = rows as u32;")?;
4480    writeln!(code, "        let c: u32 = cols as u32;")?;
4481    writeln!(
4482        code,
4483        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4484    )?;
4485    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4486    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4487    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4488    writeln!(
4489        code,
4490        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4491    )?;
4492    writeln!(
4493        code,
4494        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4495    )?;
4496    writeln!(
4497        code,
4498        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4499    )?;
4500    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4501    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4502    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4503    writeln!(
4504        code,
4505        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4506    )?;
4507    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4508    writeln!(code, "    }}")?;
4509    writeln!(code)?;
4510
4511    // dispatch_matmul_q4
4512    writeln!(
4513        code,
4514        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4515    )?;
4516    writeln!(
4517        code,
4518        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4519    )?;
4520    writeln!(
4521        code,
4522        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4523    )?;
4524    writeln!(code, "        let r: u32 = rows as u32;")?;
4525    writeln!(code, "        let c: u32 = cols as u32;")?;
4526    writeln!(
4527        code,
4528        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4529    )?;
4530    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4531    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4532    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4533    writeln!(
4534        code,
4535        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4536    )?;
4537    writeln!(
4538        code,
4539        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4540    )?;
4541    writeln!(
4542        code,
4543        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4544    )?;
4545    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4546    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4547    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4548    writeln!(
4549        code,
4550        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4551    )?;
4552    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4553    writeln!(code, "    }}")?;
4554    writeln!(code)?;
4555
4556    // dispatch_rope
4557    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4558    writeln!(
4559        code,
4560        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4561    )?;
4562    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4563    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4564    writeln!(code, "        let p: u32 = pos as u32;")?;
4565    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4566    writeln!(
4567        code,
4568        "        let total_pairs = num_heads * (head_dim / 2);"
4569    )?;
4570    writeln!(
4571        code,
4572        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4573    )?;
4574    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4575    writeln!(
4576        code,
4577        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4578    )?;
4579    writeln!(
4580        code,
4581        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4582    )?;
4583    writeln!(
4584        code,
4585        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4586    )?;
4587    writeln!(
4588        code,
4589        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4590    )?;
4591    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4592    writeln!(
4593        code,
4594        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4595    )?;
4596    writeln!(
4597        code,
4598        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4599    )?;
4600    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4601    writeln!(code, "    }}")?;
4602    writeln!(code)?;
4603
4604    // dispatch_rope_offset
4605    writeln!(
4606        code,
4607        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4608    )?;
4609    writeln!(
4610        code,
4611        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4612    )?;
4613    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4614    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4615    writeln!(code, "        let p: u32 = pos as u32;")?;
4616    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4617    writeln!(
4618        code,
4619        "        let total_pairs = num_heads * (head_dim / 2);"
4620    )?;
4621    writeln!(
4622        code,
4623        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4624    )?;
4625    writeln!(
4626        code,
4627        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4628    )?;
4629    writeln!(
4630        code,
4631        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4632    )?;
4633    writeln!(
4634        code,
4635        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4636    )?;
4637    writeln!(
4638        code,
4639        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4640    )?;
4641    writeln!(
4642        code,
4643        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4644    )?;
4645    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4646    writeln!(
4647        code,
4648        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4649    )?;
4650    writeln!(
4651        code,
4652        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4653    )?;
4654    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4655    writeln!(code, "    }}")?;
4656    writeln!(code)?;
4657
4658    // dispatch_attention
4659    writeln!(
4660        code,
4661        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4662    )?;
4663    writeln!(
4664        code,
4665        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4666    )?;
4667    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4668    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4669    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4670    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4671    writeln!(
4672        code,
4673        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4674    )?;
4675    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4676    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4677    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4678    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4679    writeln!(
4680        code,
4681        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4682    )?;
4683    writeln!(
4684        code,
4685        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4686    )?;
4687    writeln!(
4688        code,
4689        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4690    )?;
4691    writeln!(
4692        code,
4693        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4694    )?;
4695    writeln!(
4696        code,
4697        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4698    )?;
4699    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4700    writeln!(
4701        code,
4702        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4703    )?;
4704    writeln!(
4705        code,
4706        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4707    )?;
4708    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4709    writeln!(code, "    }}")?;
4710    writeln!(code)?;
4711
4712    // dispatch_attention_offset
4713    writeln!(
4714        code,
4715        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4716    )?;
4717    writeln!(
4718        code,
4719        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4720    )?;
4721    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4722    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4723    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4724    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4725    writeln!(
4726        code,
4727        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4728    )?;
4729    writeln!(
4730        code,
4731        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
4732    )?;
4733    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4734    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4735    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4736    writeln!(
4737        code,
4738        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4739    )?;
4740    writeln!(
4741        code,
4742        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4743    )?;
4744    writeln!(
4745        code,
4746        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4747    )?;
4748    writeln!(
4749        code,
4750        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4751    )?;
4752    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4753    writeln!(
4754        code,
4755        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4756    )?;
4757    writeln!(
4758        code,
4759        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4760    )?;
4761    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4762    writeln!(code, "    }}")?;
4763    writeln!(code)?;
4764
4765    // dispatch_silu_mul
4766    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
4767    writeln!(
4768        code,
4769        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
4770    )?;
4771    writeln!(code, "        let count: u32 = n as u32;")?;
4772    writeln!(
4773        code,
4774        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
4775    )?;
4776    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
4777    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
4778    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4779    writeln!(
4780        code,
4781        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4782    )?;
4783    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4784    writeln!(
4785        code,
4786        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4787    )?;
4788    writeln!(
4789        code,
4790        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4791    )?;
4792    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4793    writeln!(code, "    }}")?;
4794    writeln!(code)?;
4795
4796    // dispatch_silu_mul_fused
4797    writeln!(
4798        code,
4799        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
4800    )?;
4801    writeln!(
4802        code,
4803        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
4804    )?;
4805    writeln!(
4806        code,
4807        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
4808    )?;
4809    writeln!(code, "        let count: u32 = n as u32;")?;
4810    writeln!(
4811        code,
4812        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
4813    )?;
4814    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
4815    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
4816    writeln!(
4817        code,
4818        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4819    )?;
4820    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4821    writeln!(
4822        code,
4823        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4824    )?;
4825    writeln!(
4826        code,
4827        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4828    )?;
4829    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4830    writeln!(code, "    }}")?;
4831    writeln!(code)?;
4832
4833    // dispatch_copy (simple src -> dst copy via compute kernel)
4834    writeln!(
4835        code,
4836        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
4837    )?;
4838    writeln!(
4839        code,
4840        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4841    )?;
4842    writeln!(code, "        let n: u32 = count as u32;")?;
4843    writeln!(
4844        code,
4845        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4846    )?;
4847    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4848    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4849    writeln!(
4850        code,
4851        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4852    )?;
4853    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4854    writeln!(
4855        code,
4856        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4857    )?;
4858    writeln!(
4859        code,
4860        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4861    )?;
4862    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4863    writeln!(code, "    }}")?;
4864    writeln!(code)?;
4865
4866    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
4867    writeln!(
4868        code,
4869        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
4870    )?;
4871    writeln!(
4872        code,
4873        "    /// Used for embedding table lookup (copy a specific row)."
4874    )?;
4875    writeln!(
4876        code,
4877        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
4878    )?;
4879    writeln!(code, "        let off: u32 = src_offset as u32;")?;
4880    writeln!(code, "        let n: u32 = count as u32;")?;
4881    writeln!(
4882        code,
4883        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
4884    )?;
4885    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4886    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4887    writeln!(
4888        code,
4889        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
4890    )?;
4891    writeln!(
4892        code,
4893        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4894    )?;
4895    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4896    writeln!(
4897        code,
4898        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4899    )?;
4900    writeln!(
4901        code,
4902        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4903    )?;
4904    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4905    writeln!(code, "    }}")?;
4906    writeln!(code)?;
4907
4908    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
4909    writeln!(
4910        code,
4911        "    /// Dispatch copy from source at byte offset to destination at float offset."
4912    )?;
4913    writeln!(
4914        code,
4915        "    /// Used for KV cache updates from fused QKV buffer."
4916    )?;
4917    writeln!(
4918        code,
4919        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4920    )?;
4921    writeln!(code, "        let n: u32 = count as u32;")?;
4922    writeln!(
4923        code,
4924        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4925    )?;
4926    writeln!(
4927        code,
4928        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4929    )?;
4930    writeln!(
4931        code,
4932        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4933    )?;
4934    writeln!(
4935        code,
4936        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4937    )?;
4938    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4939    writeln!(
4940        code,
4941        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4942    )?;
4943    writeln!(
4944        code,
4945        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4946    )?;
4947    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4948    writeln!(code, "    }}")?;
4949    writeln!(code)?;
4950
4951    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
4952    writeln!(
4953        code,
4954        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
4955    )?;
4956    writeln!(
4957        code,
4958        "    /// Used for KV cache updates (write to a specific position in the cache)."
4959    )?;
4960    writeln!(
4961        code,
4962        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
4963    )?;
4964    writeln!(code, "        let n: u32 = count as u32;")?;
4965    writeln!(
4966        code,
4967        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4968    )?;
4969    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4970    writeln!(
4971        code,
4972        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
4973    )?;
4974    writeln!(
4975        code,
4976        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4977    )?;
4978    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4979    writeln!(
4980        code,
4981        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4982    )?;
4983    writeln!(
4984        code,
4985        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4986    )?;
4987    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4988    writeln!(code, "    }}")?;
4989    writeln!(code)?;
4990
4991    // dispatch_add_inplace (residual connection, no blit needed)
4992    writeln!(
4993        code,
4994        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
4995    )?;
4996    writeln!(
4997        code,
4998        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
4999    )?;
5000    writeln!(code, "        let count: u32 = n as u32;")?;
5001    writeln!(
5002        code,
5003        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5004    )?;
5005    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5006    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5007    writeln!(
5008        code,
5009        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5010    )?;
5011    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5012    writeln!(
5013        code,
5014        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5015    )?;
5016    writeln!(
5017        code,
5018        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5019    )?;
5020    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5021    writeln!(code, "    }}")?;
5022    writeln!(code)?;
5023
5024    // ── Batched prefill dispatch helpers ──
5025    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
5026    writeln!(code)?;
5027
5028    // dispatch_copy_embedding_batch
5029    writeln!(
5030        code,
5031        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
5032    )?;
5033    writeln!(
5034        code,
5035        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5036    )?;
5037    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
5038    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5039    writeln!(
5040        code,
5041        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5042    )?;
5043    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5044    writeln!(
5045        code,
5046        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5047    )?;
5048    writeln!(
5049        code,
5050        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5051    )?;
5052    writeln!(
5053        code,
5054        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5055    )?;
5056    writeln!(
5057        code,
5058        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5059    )?;
5060    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
5061    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5062    writeln!(
5063        code,
5064        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5065    )?;
5066    writeln!(
5067        code,
5068        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5069    )?;
5070    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5071    writeln!(code, "    }}")?;
5072    writeln!(code)?;
5073
5074    // dispatch_rms_norm_batch
5075    writeln!(
5076        code,
5077        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
5078    )?;
5079    writeln!(
5080        code,
5081        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5082    )?;
5083    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
5084    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
5085    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5086    writeln!(
5087        code,
5088        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5089    )?;
5090    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
5091    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
5092    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5093    writeln!(
5094        code,
5095        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5096    )?;
5097    writeln!(
5098        code,
5099        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5100    )?;
5101    writeln!(
5102        code,
5103        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5104    )?;
5105    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5106    writeln!(
5107        code,
5108        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5109    )?;
5110    writeln!(
5111        code,
5112        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5113    )?;
5114    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5115    writeln!(code, "    }}")?;
5116    writeln!(code)?;
5117
5118    // dispatch_matmul_batch (f32)
5119    writeln!(
5120        code,
5121        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5122    )?;
5123    writeln!(
5124        code,
5125        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5126    )?;
5127    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5128    writeln!(code, "        let r: u32 = rows as u32;")?;
5129    writeln!(code, "        let c: u32 = cols as u32;")?;
5130    writeln!(
5131        code,
5132        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5133    )?;
5134    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5135    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5136    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5137    writeln!(
5138        code,
5139        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5140    )?;
5141    writeln!(
5142        code,
5143        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5144    )?;
5145    writeln!(
5146        code,
5147        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5148    )?;
5149    writeln!(
5150        code,
5151        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
5152    )?;
5153    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5154    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5155    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5156    writeln!(
5157        code,
5158        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5159    )?;
5160    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5161    writeln!(code, "    }}")?;
5162    writeln!(code)?;
5163
5164    // dispatch_matmul_q8_batch
5165    writeln!(
5166        code,
5167        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5168    )?;
5169    writeln!(code, "    ///")?;
5170    writeln!(
5171        code,
5172        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5173    )?;
5174    writeln!(
5175        code,
5176        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5177    )?;
5178    writeln!(
5179        code,
5180        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5181    )?;
5182    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5183    writeln!(code, "        let r: u32 = rows as u32;")?;
5184    writeln!(code, "        let c: u32 = cols as u32;")?;
5185    writeln!(
5186        code,
5187        "        // Tile sizes must match the Metal shader constants."
5188    )?;
5189    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5190    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5191    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5192    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5193    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5194    writeln!(
5195        code,
5196        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5197    )?;
5198    writeln!(
5199        code,
5200        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5201    )?;
5202    writeln!(
5203        code,
5204        "        // dispatch count and reuses each weight load across 32 tokens."
5205    )?;
5206    writeln!(
5207        code,
5208        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5209    )?;
5210    writeln!(
5211        code,
5212        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5213    )?;
5214    writeln!(
5215        code,
5216        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5217    )?;
5218    writeln!(
5219        code,
5220        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5221    )?;
5222    writeln!(
5223        code,
5224        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5225    )?;
5226    writeln!(
5227        code,
5228        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5229    )?;
5230    writeln!(
5231        code,
5232        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5233    )?;
5234    writeln!(
5235        code,
5236        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5237    )?;
5238    writeln!(
5239        code,
5240        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5241    )?;
5242    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5243    writeln!(code, "            let pipe = if use_h4 {{")?;
5244    writeln!(code, "                if num_tokens >= 256 {{")?;
5245    writeln!(code, "                    &self.matmul_q8_mma32_hh4_pipeline")?;
5246    writeln!(code, "                }} else {{")?;
5247    writeln!(code, "                    &self.matmul_q8_mma32_h4_pipeline")?;
5248    writeln!(code, "                }}")?;
5249    writeln!(code, "            }} else {{")?;
5250    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5251    writeln!(code, "            }};")?;
5252    writeln!(
5253        code,
5254        "            enc.set_compute_pipeline_state(pipe);"
5255    )?;
5256    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5257    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5258    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5259    writeln!(
5260        code,
5261        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5262    )?;
5263    writeln!(
5264        code,
5265        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5266    )?;
5267    writeln!(
5268        code,
5269        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5270    )?;
5271    writeln!(
5272        code,
5273        "            let row_tgs = rows / MMA32_ROW_TILE;"
5274    )?;
5275    writeln!(
5276        code,
5277        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5278    )?;
5279    writeln!(
5280        code,
5281        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5282    )?;
5283    writeln!(
5284        code,
5285        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5286    )?;
5287    writeln!(
5288        code,
5289        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5290    )?;
5291    writeln!(
5292        code,
5293        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5294    )?;
5295    writeln!(
5296        code,
5297        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5298    )?;
5299    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5300    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5301    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5302    writeln!(
5303        code,
5304        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5305    )?;
5306    writeln!(
5307        code,
5308        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5309    )?;
5310    writeln!(
5311        code,
5312        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5313    )?;
5314    writeln!(
5315        code,
5316        "            let row_tgs = rows / MMA_ROW_TILE;"
5317    )?;
5318    writeln!(
5319        code,
5320        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5321    )?;
5322    writeln!(
5323        code,
5324        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5325    )?;
5326    writeln!(
5327        code,
5328        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5329    )?;
5330    writeln!(
5331        code,
5332        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5333    )?;
5334    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5335    writeln!(
5336        code,
5337        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5338    )?;
5339    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5340    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5341    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5342    writeln!(
5343        code,
5344        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5345    )?;
5346    writeln!(
5347        code,
5348        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5349    )?;
5350    writeln!(
5351        code,
5352        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5353    )?;
5354    writeln!(
5355        code,
5356        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5357    )?;
5358    writeln!(
5359        code,
5360        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5361    )?;
5362    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5363    writeln!(
5364        code,
5365        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5366    )?;
5367    writeln!(
5368        code,
5369        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5370    )?;
5371    writeln!(code, "        }} else {{")?;
5372    writeln!(
5373        code,
5374        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5375    )?;
5376    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5377    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5378    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5379    writeln!(
5380        code,
5381        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5382    )?;
5383    writeln!(
5384        code,
5385        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5386    )?;
5387    writeln!(
5388        code,
5389        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5390    )?;
5391    writeln!(
5392        code,
5393        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5394    )?;
5395    writeln!(
5396        code,
5397        "            let num_tg = (row_tgs * num_tokens) as u64;"
5398    )?;
5399    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5400    writeln!(
5401        code,
5402        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5403    )?;
5404    writeln!(
5405        code,
5406        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5407    )?;
5408    writeln!(code, "        }}")?;
5409    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5410    writeln!(code, "    }}")?;
5411    writeln!(code)?;
5412
5413    // dispatch_matmul_q4_batch
5414    writeln!(
5415        code,
5416        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5417    )?;
5418    writeln!(
5419        code,
5420        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5421    )?;
5422    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5423    writeln!(code, "        let r: u32 = rows as u32;")?;
5424    writeln!(code, "        let c: u32 = cols as u32;")?;
5425    writeln!(
5426        code,
5427        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5428    )?;
5429    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5430    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5431    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5432    writeln!(
5433        code,
5434        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5435    )?;
5436    writeln!(
5437        code,
5438        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5439    )?;
5440    writeln!(
5441        code,
5442        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5443    )?;
5444    writeln!(
5445        code,
5446        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5447    )?;
5448    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5449    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5450    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5451    writeln!(
5452        code,
5453        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5454    )?;
5455    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5456    writeln!(code, "    }}")?;
5457    writeln!(code)?;
5458
5459    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5460    if config.qkv_bias {
5461        writeln!(
5462            code,
5463            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5464        )?;
5465        writeln!(
5466            code,
5467            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5468        )?;
5469        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5470        writeln!(code, "        let r: u32 = rows as u32;")?;
5471        writeln!(
5472            code,
5473            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5474        )?;
5475        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5476        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5477        writeln!(
5478            code,
5479            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5480        )?;
5481        writeln!(
5482            code,
5483            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5484        )?;
5485        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5486        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5487        writeln!(
5488            code,
5489            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5490        )?;
5491        writeln!(
5492            code,
5493            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5494        )?;
5495        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5496        writeln!(code, "    }}")?;
5497        writeln!(code)?;
5498    }
5499
5500    // dispatch_rope_batch
5501    writeln!(
5502        code,
5503        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5504    )?;
5505    writeln!(
5506        code,
5507        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5508    )?;
5509    writeln!(
5510        code,
5511        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5512    )?;
5513    writeln!(
5514        code,
5515        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5516    )?;
5517    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5518    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5519    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5520    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5521    writeln!(
5522        code,
5523        "        let pairs_per_token = num_heads * (head_dim / 2);"
5524    )?;
5525    writeln!(
5526        code,
5527        "        let total_pairs = num_tokens * pairs_per_token;"
5528    )?;
5529    // Layout note: the `rope_batch` kernel expects contiguous
5530    // [M, num_heads * head_dim] data and indexes it as
5531    // data[token * (num_heads * head_dim) + ...].
5532    // `batch_qkv_buf` is instead [M, qkv_rows], with Q and K stored at an
5533    // offset inside each token's row, so the element we need lives at
5534    // data[token * row_stride + data_float_offset + ...].
5535    // Binding the buffer at byte offset (data_float_offset * 4) is not enough on
5536    // its own: the kernel would also need a row-stride parameter, which
5537    // `rope_batch` as written does not take.
5538    //
5539    // Simplest approach: dispatch the single-token rope kernel once per token.
5540    // This is still efficient because all dispatches share one command encoder.
5541    writeln!(
5542        code,
5543        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5544    )?;
5545    writeln!(code, "        for t in 0..num_tokens {{")?;
5546    writeln!(
5547        code,
5548        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5549    )?;
5550    writeln!(
5551        code,
5552        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5553    )?;
5554    writeln!(
5555        code,
5556        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5557    )?;
5558    writeln!(
5559        code,
5560        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5561    )?;
5562    writeln!(
5563        code,
5564        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5565    )?;
5566    writeln!(
5567        code,
5568        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5569    )?;
5570    writeln!(
5571        code,
5572        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5573    )?;
5574    writeln!(
5575        code,
5576        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5577    )?;
5578    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5579    writeln!(
5580        code,
5581        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5582    )?;
5583    writeln!(
5584        code,
5585        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5586    )?;
5587    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5588    writeln!(code, "        }}")?;
5589    writeln!(code, "    }}")?;
5590    writeln!(code)?;
5591
5592    // dispatch_silu_mul_fused_batch
5593    writeln!(
5594        code,
5595        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5596    )?;
5597    writeln!(
5598        code,
5599        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5600    )?;
5601    writeln!(code, "        let count: u32 = n as u32;")?;
5602    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5603    writeln!(
5604        code,
5605        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5606    )?;
5607    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5608    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5609    writeln!(
5610        code,
5611        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5612    )?;
5613    writeln!(
5614        code,
5615        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5616    )?;
5617    writeln!(code, "        let total = num_tokens * n;")?;
5618    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5619    writeln!(
5620        code,
5621        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5622    )?;
5623    writeln!(
5624        code,
5625        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5626    )?;
5627    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5628    writeln!(code, "    }}")?;
5629    writeln!(code)?;
5630
5631    // dispatch_add_inplace_batch_n (add n elements in-place)
5632    writeln!(
5633        code,
5634        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5635    )?;
5636    writeln!(
5637        code,
5638        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5639    )?;
5640    writeln!(code, "        let count: u32 = total_n as u32;")?;
5641    writeln!(
5642        code,
5643        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5644    )?;
5645    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5646    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5647    writeln!(
5648        code,
5649        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5650    )?;
5651    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5652    writeln!(
5653        code,
5654        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5655    )?;
5656    writeln!(
5657        code,
5658        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5659    )?;
5660    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5661    writeln!(code, "    }}")?;
5662    writeln!(code)?;
5663
5664    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5665    writeln!(
5666        code,
5667        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5668    )?;
5669    writeln!(
5670        code,
5671        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5672    )?;
5673    writeln!(code, "        let n: u32 = count as u32;")?;
5674    writeln!(
5675        code,
5676        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5677    )?;
5678    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5679    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5680    writeln!(
5681        code,
5682        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5683    )?;
5684    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5685    writeln!(
5686        code,
5687        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5688    )?;
5689    writeln!(
5690        code,
5691        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5692    )?;
5693    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5694    writeln!(code, "    }}")?;
5695    writeln!(code)?;
5696
5697    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
5698    writeln!(
5699        code,
5700        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5701    )?;
5702    writeln!(
5703        code,
5704        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5705    )?;
5706    writeln!(code, "        let n: u32 = count as u32;")?;
5707    writeln!(
5708        code,
5709        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5710    )?;
5711    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5712    writeln!(
5713        code,
5714        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5715    )?;
5716    writeln!(
5717        code,
5718        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5719    )?;
5720    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5721    writeln!(
5722        code,
5723        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5724    )?;
5725    writeln!(
5726        code,
5727        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5728    )?;
5729    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5730    writeln!(code, "    }}")?;
5731    writeln!(code)?;
5732
5733    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
5734    writeln!(
5735        code,
5736        "    /// Copy from src at byte offset to dst at float offset."
5737    )?;
5738    writeln!(
5739        code,
5740        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5741    )?;
5742    writeln!(code, "        let n: u32 = count as u32;")?;
5743    writeln!(
5744        code,
5745        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5746    )?;
5747    writeln!(
5748        code,
5749        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5750    )?;
5751    writeln!(
5752        code,
5753        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5754    )?;
5755    writeln!(
5756        code,
5757        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5758    )?;
5759    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5760    writeln!(
5761        code,
5762        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5763    )?;
5764    writeln!(
5765        code,
5766        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5767    )?;
5768    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5769    writeln!(code, "    }}")?;
5770    writeln!(code)?;
5771
5772    // dispatch_copy_kv_batch
5773    writeln!(
5774        code,
5775        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
5776    )?;
5777    writeln!(
5778        code,
5779        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
5780    )?;
5781    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5782    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5783    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5784    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5785    writeln!(code, "        let so: u32 = src_offset as u32;")?;
5786    writeln!(
5787        code,
5788        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
5789    )?;
5790    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5791    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5792    writeln!(
5793        code,
5794        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5795    )?;
5796    writeln!(
5797        code,
5798        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5799    )?;
5800    writeln!(
5801        code,
5802        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5803    )?;
5804    writeln!(
5805        code,
5806        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5807    )?;
5808    writeln!(
5809        code,
5810        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
5811    )?;
5812    writeln!(code, "        let total = num_tokens * kv_dim;")?;
5813    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5814    writeln!(
5815        code,
5816        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5817    )?;
5818    writeln!(
5819        code,
5820        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5821    )?;
5822    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5823    writeln!(code, "    }}")?;
5824    writeln!(code)?;
5825
5826    // dispatch_attention_batch
5827    writeln!(
5828        code,
5829        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
5830    )?;
5831    writeln!(
5832        code,
5833        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
5834    )?;
5835    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5836    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5837    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5838    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5839    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5840    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
5841    // The legacy attention_batch kernel materializes a scores[2048] array
5842    // in threadgroup memory — fast at short seq_len but silently corrupts
5843    // when `base_pos + num_tokens > 2048`.  Switch to the flash kernel
5844    // (tiled online softmax, O(head_dim) threadgroup memory, no seq_len
5845    // cap) above that boundary.  Below it the legacy kernel is ~15% faster
5846    // due to the flash kernel's per-tile barrier overhead at low seq_len.
5847    writeln!(
5848        code,
5849        "        let max_seq = base_pos + num_tokens;"
5850    )?;
5851    writeln!(
5852        code,
5853        "        let pipe = if max_seq > 2000 {{"
5854    )?;
5855    writeln!(
5856        code,
5857        "            &self.attention_flash_batch_pipeline"
5858    )?;
5859    writeln!(code, "        }} else {{")?;
5860    writeln!(
5861        code,
5862        "            &self.attention_batch_pipeline"
5863    )?;
5864    writeln!(code, "        }};")?;
5865    writeln!(
5866        code,
5867        "        enc.set_compute_pipeline_state(pipe);"
5868    )?;
5869    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
5870    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5871    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5872    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5873    writeln!(
5874        code,
5875        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5876    )?;
5877    writeln!(
5878        code,
5879        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5880    )?;
5881    writeln!(
5882        code,
5883        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5884    )?;
5885    writeln!(
5886        code,
5887        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5888    )?;
5889    writeln!(
5890        code,
5891        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5892    )?;
5893    writeln!(
5894        code,
5895        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5896    )?;
5897    writeln!(
5898        code,
5899        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
5900    )?;
5901    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5902    writeln!(
5903        code,
5904        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
5905    )?;
5906    writeln!(
5907        code,
5908        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5909    )?;
5910    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5911    writeln!(code, "    }}")?;
5912    writeln!(code)?;
5913
5914    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
5915    writeln!(
5916        code,
5917        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
5918    )?;
5919    writeln!(
5920        code,
5921        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
5922    )?;
5923    writeln!(
5924        code,
5925        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
5926    )?;
5927    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5928    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5929    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
5930    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5931    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5932    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
5933    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5934    writeln!(
5935        code,
5936        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
5937    )?;
5938    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
5939    writeln!(
5940        code,
5941        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5942    )?;
5943    writeln!(
5944        code,
5945        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5946    )?;
5947    writeln!(
5948        code,
5949        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
5950    )?;
5951    writeln!(
5952        code,
5953        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5954    )?;
5955    writeln!(
5956        code,
5957        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5958    )?;
5959    writeln!(
5960        code,
5961        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5962    )?;
5963    writeln!(
5964        code,
5965        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5966    )?;
5967    writeln!(
5968        code,
5969        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
5970    )?;
5971    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5972    writeln!(
5973        code,
5974        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5975    )?;
5976    writeln!(
5977        code,
5978        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5979    )?;
5980    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5981    writeln!(code, "    }}")?;
5982    writeln!(code)?;
5983
5984    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
5985    writeln!(
5986        code,
5987        "    /// Dispatch fused K+V cache copy in one kernel launch."
5988    )?;
5989    writeln!(
5990        code,
5991        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
5992    )?;
5993    writeln!(
5994        code,
5995        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
5996    )?;
5997    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5998    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5999    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6000    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6001    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6002    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6003    writeln!(
6004        code,
6005        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6006    )?;
6007    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6008    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6009    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6010    writeln!(
6011        code,
6012        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6013    )?;
6014    writeln!(
6015        code,
6016        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6017    )?;
6018    writeln!(
6019        code,
6020        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6021    )?;
6022    writeln!(
6023        code,
6024        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6025    )?;
6026    writeln!(
6027        code,
6028        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6029    )?;
6030    writeln!(
6031        code,
6032        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6033    )?;
6034    writeln!(
6035        code,
6036        "        let total = num_tokens * kv_dim * 2;  // K + V"
6037    )?;
6038    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6039    writeln!(
6040        code,
6041        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6042    )?;
6043    writeln!(
6044        code,
6045        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6046    )?;
6047    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6048    writeln!(code, "    }}")?;
6049
6050    writeln!(code, "}}")?;
6051    writeln!(code)?;
6052
6053    Ok(())
6054}
6055
6056fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6057    writeln!(
6058        code,
6059        "// ── Helper functions ──────────────────────────────────"
6060    )?;
6061    writeln!(code)?;
6062    writeln!(
6063        code,
6064        "/// Create a compute pipeline from a named function in the Metal library."
6065    )?;
6066    writeln!(
6067        code,
6068        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6069    )?;
6070    writeln!(
6071        code,
6072        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6073    )?;
6074    writeln!(
6075        code,
6076        "    device.new_compute_pipeline_state_with_function(&func)"
6077    )?;
6078    writeln!(
6079        code,
6080        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6081    )?;
6082    writeln!(code, "}}")?;
6083    writeln!(code)?;
6084
6085    Ok(())
6086}
6087
6088// ---------------------------------------------------------------------------
6089// main.rs generation
6090// ---------------------------------------------------------------------------
6091
6092fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6093    let _sanitized = sanitize_name(model_name);
6094    let _vocab = config.vocab_size;
6095
6096    let mut code = String::with_capacity(16 * 1024);
6097    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6098    writeln!(
6099        code,
6100        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6101    )?;
6102    writeln!(code)?;
6103    writeln!(code, "mod model;")?;
6104    writeln!(code)?;
6105    writeln!(code, "use std::io::Write;")?;
6106    writeln!(code, "use std::time::Instant;")?;
6107    writeln!(code, "use serde::Deserialize;")?;
6108    writeln!(code)?;
6109
6110    // -- main function --
6111    writeln!(code, "fn main() {{")?;
6112    writeln!(
6113        code,
6114        "    let args: Vec<String> = std::env::args().collect();"
6115    )?;
6116    writeln!(code)?;
6117    writeln!(
6118        code,
6119        "    // Detect --serve mode (only requires weights + tokenizer)"
6120    )?;
6121    writeln!(
6122        code,
6123        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6124    )?;
6125    writeln!(code)?;
6126    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6127    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6128    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6129    writeln!(code, "        std::process::exit(1);")?;
6130    writeln!(code, "    }}")?;
6131    writeln!(code)?;
6132    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6133    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6134    writeln!(code, "        std::process::exit(1);")?;
6135    writeln!(code, "    }}")?;
6136    writeln!(code)?;
6137    writeln!(code, "    let weights_path = &args[1];")?;
6138    writeln!(code, "    let tokenizer_path = &args[2];")?;
6139    writeln!(code)?;
6140    writeln!(code, "    // Parse optional flags")?;
6141    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6142    writeln!(code, "    let mut port: u16 = 8080;")?;
6143    writeln!(
6144        code,
6145        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6146    )?;
6147    writeln!(
6148        code,
6149        "    let profile = args.iter().any(|a| a == \"--profile\");"
6150    )?;
6151    writeln!(code, "    let mut i = 3;")?;
6152    writeln!(code, "    while i < args.len() {{")?;
6153    writeln!(
6154        code,
6155        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6156    )?;
6157    writeln!(
6158        code,
6159        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6160    )?;
6161    writeln!(code, "            i += 2;")?;
6162    writeln!(
6163        code,
6164        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6165    )?;
6166    writeln!(
6167        code,
6168        "            port = args[i + 1].parse().unwrap_or(8080);"
6169    )?;
6170    writeln!(code, "            i += 2;")?;
6171    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6172    writeln!(code, "            i += 1;")?;
6173    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6174    writeln!(code, "            i += 1;")?;
6175    writeln!(code, "        }} else {{")?;
6176    writeln!(code, "            i += 1;")?;
6177    writeln!(code, "        }}")?;
6178    writeln!(code, "    }}")?;
6179    writeln!(code)?;
6180
6181    // -- load model (shared by both modes) --
6182    writeln!(
6183        code,
6184        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6185    )?;
6186    writeln!(
6187        code,
6188        "    let weights_file = std::fs::File::open(weights_path)"
6189    )?;
6190    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6191    writeln!(
6192        code,
6193        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6194    )?;
6195    writeln!(code)?;
6196    writeln!(code, "    // Load tokenizer")?;
6197    writeln!(
6198        code,
6199        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6200    )?;
6201    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6202    writeln!(code)?;
6203    writeln!(code, "    // Create Metal model")?;
6204    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6205    writeln!(
6206        code,
6207        "    let mut model = model::MetalModel::new(&weights_mmap);"
6208    )?;
6209    writeln!(code)?;
6210
6211    // -- branch: serve vs CLI --
6212    writeln!(code, "    if serve_mode {{")?;
6213    writeln!(code, "        serve(model, tokenizer, port);")?;
6214    writeln!(code, "    }} else {{")?;
6215    writeln!(code, "        let prompt = &args[3];")?;
6216    writeln!(
6217        code,
6218        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6219    )?;
6220    writeln!(code, "    }}")?;
6221    writeln!(code, "}}")?;
6222    writeln!(code)?;
6223
6224    // -- cli_mode function --
6225    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6226    writeln!(code, "    // Tokenize prompt")?;
6227    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6228    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6229    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6230    writeln!(code)?;
6231    writeln!(
6232        code,
6233        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6234    )?;
6235    writeln!(
6236        code,
6237        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6238    )?;
6239    writeln!(
6240        code,
6241        "    // The last token uses synchronous forward() to get logits."
6242    )?;
6243    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6244    writeln!(code, "    let prefill_start = Instant::now();")?;
6245    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6246    writeln!(
6247        code,
6248        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6249    )?;
6250    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6251    writeln!(code, "    }} else {{")?;
6252    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6253    writeln!(code, "    }};")?;
6254    writeln!(
6255        code,
6256        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6257    )?;
6258    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6259    writeln!(
6260        code,
6261        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6262    )?;
6263    writeln!(
6264        code,
6265        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6266    )?;
6267    writeln!(code)?;
6268    writeln!(code, "    // Generate tokens")?;
6269    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6270    writeln!(code, "    let gen_start = Instant::now();")?;
6271    writeln!(code, "    let mut generated_count: usize = 0;")?;
6272    writeln!(code)?;
6273    writeln!(code, "    for _ in 0..max_tokens {{")?;
6274    writeln!(
6275        code,
6276        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6277    )?;
6278    writeln!(code, "            if !quiet {{")?;
6279    writeln!(code, "                print!(\"{{}}\", text);")?;
6280    writeln!(code, "                std::io::stdout().flush().ok();")?;
6281    writeln!(code, "            }}")?;
6282    writeln!(code, "        }}")?;
6283    writeln!(code, "        generated_count += 1;")?;
6284    writeln!(code)?;
6285    writeln!(
6286        code,
6287        "        // Use profiling forward for first token when --profile is set"
6288    )?;
6289    writeln!(
6290        code,
6291        "        let logits = if profile && generated_count == 1 {{"
6292    )?;
6293    writeln!(code, "            model.forward_profile(next_token)")?;
6294    writeln!(code, "        }} else {{")?;
6295    writeln!(code, "            model.forward(next_token)")?;
6296    writeln!(code, "        }};")?;
6297    writeln!(code, "        next_token = argmax(&logits);")?;
6298    writeln!(code)?;
6299    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6300    writeln!(code, "        if next_token == 2 {{")?;
6301    writeln!(code, "            break;")?;
6302    writeln!(code, "        }}")?;
6303    writeln!(code)?;
6304    writeln!(
6305        code,
6306        "        // Yield between tokens to reduce sustained GPU thermal load."
6307    )?;
6308    writeln!(
6309        code,
6310        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6311    )?;
6312    writeln!(
6313        code,
6314        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6315    )?;
6316    writeln!(
6317        code,
6318        "        // briefly, providing a micro-break that helps sustain peak throughput."
6319    )?;
6320    writeln!(code, "        std::thread::yield_now();")?;
6321    writeln!(code, "    }}")?;
6322    writeln!(code, "    if !quiet {{")?;
6323    writeln!(code, "        println!();")?;
6324    writeln!(code, "    }}")?;
6325    writeln!(
6326        code,
6327        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6328    )?;
6329    writeln!(
6330        code,
6331        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6332    )?;
6333    writeln!(
6334        code,
6335        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6336    )?;
6337    writeln!(code, "}}")?;
6338    writeln!(code)?;
6339
6340    // -- argmax helper --
6341    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6342    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6343    writeln!(code, "    logits.iter()")?;
6344    writeln!(code, "        .enumerate()")?;
6345    writeln!(
6346        code,
6347        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6348    )?;
6349    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6350    writeln!(code, "        .unwrap_or(0)")?;
6351    writeln!(code, "}}")?;
6352    writeln!(code)?;
6353
6354    // -- Request/Response types for OpenAI API --
6355    writeln!(
6356        code,
6357        "// -----------------------------------------------------------------------"
6358    )?;
6359    writeln!(code, "// OpenAI-compatible API server")?;
6360    writeln!(
6361        code,
6362        "// -----------------------------------------------------------------------"
6363    )?;
6364    writeln!(code)?;
6365    writeln!(code, "#[derive(Deserialize)]")?;
6366    writeln!(code, "struct ChatRequest {{")?;
6367    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6368    writeln!(code, "    #[serde(default)]")?;
6369    writeln!(code, "    stream: Option<bool>,")?;
6370    writeln!(code, "    #[serde(default)]")?;
6371    writeln!(code, "    max_tokens: Option<usize>,")?;
6372    writeln!(code, "    #[serde(default)]")?;
6373    writeln!(code, "    temperature: Option<f32>,")?;
6374    writeln!(code, "    #[serde(default)]")?;
6375    writeln!(code, "    model: Option<String>,")?;
6376    writeln!(code, "}}")?;
6377    writeln!(code)?;
6378    writeln!(code, "#[derive(Deserialize)]")?;
6379    writeln!(code, "struct ChatMessage {{")?;
6380    writeln!(code, "    role: String,")?;
6381    writeln!(code, "    content: String,")?;
6382    writeln!(code, "}}")?;
6383    writeln!(code)?;
6384
6385    // -- format_chat_messages --
6386    writeln!(
6387        code,
6388        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6389    )?;
6390    writeln!(code, "    let mut prompt = String::new();")?;
6391    writeln!(code, "    for msg in messages {{")?;
6392    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6393    writeln!(code, "    }}")?;
6394    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6395    writeln!(code, "    prompt")?;
6396    writeln!(code, "}}")?;
6397    writeln!(code)?;
6398
6399    // -- prefill helper --
6400    writeln!(
6401        code,
6402        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6403    )?;
6404    writeln!(code, "    let len = tokens.len();")?;
6405    writeln!(code, "    if len > 1 {{")?;
6406    writeln!(
6407        code,
6408        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6409    )?;
6410    writeln!(code, "    }}")?;
6411    writeln!(code, "    model.forward(tokens[len - 1])")?;
6412    writeln!(code, "}}")?;
6413    writeln!(code)?;
6414
6415    // -- serve function --
6416    writeln!(
6417        code,
6418        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6419    )?;
6420    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6421    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6422    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6423    writeln!(
6424        code,
6425        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6426    )?;
6427    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6428    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6429    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6430    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6431    writeln!(code)?;
6432    writeln!(code, "    for request in server.incoming_requests() {{")?;
6433    writeln!(code, "        let method = request.method().to_string();")?;
6434    writeln!(code, "        let url = request.url().to_string();")?;
6435    writeln!(code)?;
6436    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6437
6438    // -- POST /v1/chat/completions --
6439    writeln!(
6440        code,
6441        "            (\"POST\", \"/v1/chat/completions\") => {{"
6442    )?;
6443    writeln!(
6444        code,
6445        "                handle_chat_completion(&mut model, &tokenizer, request);"
6446    )?;
6447    writeln!(code, "            }}")?;
6448
6449    // -- GET /v1/models --
6450    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6451    writeln!(code, "                let body = serde_json::json!({{")?;
6452    writeln!(code, "                    \"object\": \"list\",")?;
6453    writeln!(code, "                    \"data\": [{{")?;
6454    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6455    writeln!(code, "                        \"object\": \"model\",")?;
6456    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6457    writeln!(code, "                    }}]")?;
6458    writeln!(code, "                }});")?;
6459    writeln!(
6460        code,
6461        "                let resp = tiny_http::Response::from_string(body.to_string())"
6462    )?;
6463    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6464    writeln!(code, "                request.respond(resp).ok();")?;
6465    writeln!(code, "            }}")?;
6466
6467    // -- GET /health --
6468    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6469    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6470    writeln!(code, "                request.respond(resp).ok();")?;
6471    writeln!(code, "            }}")?;
6472
6473    // -- 404 --
6474    writeln!(code, "            _ => {{")?;
6475    writeln!(
6476        code,
6477        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6478    )?;
6479    writeln!(code, "                    .with_status_code(404);")?;
6480    writeln!(code, "                request.respond(resp).ok();")?;
6481    writeln!(code, "            }}")?;
6482    writeln!(code, "        }}")?;
6483    writeln!(code, "    }}")?;
6484    writeln!(code, "}}")?;
6485    writeln!(code)?;
6486
6487    // -- handle_chat_completion --
6488    writeln!(code, "fn handle_chat_completion(")?;
6489    writeln!(code, "    model: &mut model::MetalModel,")?;
6490    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6491    writeln!(code, "    mut request: tiny_http::Request,")?;
6492    writeln!(code, ") {{")?;
6493    writeln!(code, "    // Read request body")?;
6494    writeln!(code, "    let mut body = String::new();")?;
6495    writeln!(
6496        code,
6497        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6498    )?;
6499    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6500    writeln!(code, "            .with_status_code(400);")?;
6501    writeln!(code, "        request.respond(resp).ok();")?;
6502    writeln!(code, "        return;")?;
6503    writeln!(code, "    }}")?;
6504    writeln!(code)?;
6505    writeln!(code, "    // Parse JSON")?;
6506    writeln!(
6507        code,
6508        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6509    )?;
6510    writeln!(code, "        Ok(r) => r,")?;
6511    writeln!(code, "        Err(e) => {{")?;
6512    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6513    writeln!(code, "                .with_status_code(400);")?;
6514    writeln!(code, "            request.respond(resp).ok();")?;
6515    writeln!(code, "            return;")?;
6516    writeln!(code, "        }}")?;
6517    writeln!(code, "    }};")?;
6518    writeln!(code)?;
6519    writeln!(
6520        code,
6521        "    let prompt = format_chat_messages(&req.messages);"
6522    )?;
6523    writeln!(
6524        code,
6525        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6526    )?;
6527    writeln!(code, "        Ok(e) => e,")?;
6528    writeln!(code, "        Err(e) => {{")?;
6529    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6530    writeln!(code, "                .with_status_code(500);")?;
6531    writeln!(code, "            request.respond(resp).ok();")?;
6532    writeln!(code, "            return;")?;
6533    writeln!(code, "        }}")?;
6534    writeln!(code, "    }};")?;
6535    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6536    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6537    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6538    writeln!(
6539        code,
6540        "    let _temperature = req.temperature.unwrap_or(1.0);"
6541    )?;
6542    writeln!(code)?;
6543
6544    // -- Reset KV cache for each request --
6545    writeln!(code, "    model.reset();")?;
6546    writeln!(code)?;
6547
6548    // -- Prefill with timing --
6549    writeln!(code, "    let prefill_start = Instant::now();")?;
6550    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6551    writeln!(
6552        code,
6553        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6554    )?;
6555    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6556    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6557    writeln!(code)?;
6558
6559    writeln!(code, "    if stream {{")?;
6560
6561    // -- SSE streaming response --
6562    writeln!(
6563        code,
6564        "        // SSE streaming: generate tokens and build SSE body"
6565    )?;
6566    writeln!(code, "        let gen_start = Instant::now();")?;
6567    writeln!(code, "        let mut generated_count: usize = 0;")?;
6568    writeln!(code, "        let mut sse_body = String::new();")?;
6569    writeln!(code, "        for _ in 0..max_tokens {{")?;
6570    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6571    writeln!(
6572        code,
6573        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6574    )?;
6575    writeln!(
6576        code,
6577        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6578    )?;
6579    writeln!(
6580        code,
6581        "                // escaped includes surrounding quotes, strip them"
6582    )?;
6583    writeln!(
6584        code,
6585        "                let inner = &escaped[1..escaped.len()-1];"
6586    )?;
6587    writeln!(code, "                sse_body.push_str(&format!(")?;
6588    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6589    writeln!(code, "                    inner")?;
6590    writeln!(code, "                ));")?;
6591    writeln!(code, "            }}")?;
6592    writeln!(code, "            generated_count += 1;")?;
6593    writeln!(code, "            let logits = model.forward(next_token);")?;
6594    writeln!(code, "            next_token = argmax(&logits);")?;
6595    writeln!(code, "        }}")?;
6596    writeln!(
6597        code,
6598        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6599    )?;
6600    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6601    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6602    writeln!(code)?;
6603    writeln!(
6604        code,
6605        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6606    )?;
6607    writeln!(code, "        sse_body.push_str(&format!(")?;
6608    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6609    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6610    writeln!(code, "        ));")?;
6611    writeln!(code)?;
6612    writeln!(
6613        code,
6614        "        let resp = tiny_http::Response::from_string(sse_body)"
6615    )?;
6616    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6617    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6618    writeln!(code, "        request.respond(resp).ok();")?;
6619
6620    writeln!(code, "    }} else {{")?;
6621
6622    // -- Non-streaming response --
6623    writeln!(
6624        code,
6625        "        // Non-streaming: generate all tokens, return JSON"
6626    )?;
6627    writeln!(code, "        let gen_start = Instant::now();")?;
6628    writeln!(code, "        let mut generated_count: usize = 0;")?;
6629    writeln!(code, "        let mut generated = String::new();")?;
6630    writeln!(code, "        for _ in 0..max_tokens {{")?;
6631    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6632    writeln!(
6633        code,
6634        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6635    )?;
6636    writeln!(code, "                generated.push_str(&text);")?;
6637    writeln!(code, "            }}")?;
6638    writeln!(code, "            generated_count += 1;")?;
6639    writeln!(code, "            let logits = model.forward(next_token);")?;
6640    writeln!(code, "            next_token = argmax(&logits);")?;
6641    writeln!(code, "        }}")?;
6642    writeln!(
6643        code,
6644        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6645    )?;
6646    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6647    writeln!(code)?;
6648    writeln!(code, "        let resp_json = serde_json::json!({{")?;
6649    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
6650    writeln!(code, "            \"object\": \"chat.completion\",")?;
6651    writeln!(code, "            \"choices\": [{{")?;
6652    writeln!(code, "                \"index\": 0,")?;
6653    writeln!(code, "                \"message\": {{")?;
6654    writeln!(code, "                    \"role\": \"assistant\",")?;
6655    writeln!(code, "                    \"content\": generated")?;
6656    writeln!(code, "                }},")?;
6657    writeln!(code, "                \"finish_reason\": \"stop\"")?;
6658    writeln!(code, "            }}],")?;
6659    writeln!(code, "            \"usage\": {{")?;
6660    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
6661    writeln!(
6662        code,
6663        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6664    )?;
6665    writeln!(
6666        code,
6667        "                \"generation_tokens\": generated_count,"
6668    )?;
6669    writeln!(
6670        code,
6671        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6672    )?;
6673    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
6674    writeln!(code, "            }}")?;
6675    writeln!(code, "        }});")?;
6676    writeln!(
6677        code,
6678        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
6679    )?;
6680    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6681    writeln!(code, "        request.respond(resp).ok();")?;
6682    writeln!(code, "    }}")?;
6683    writeln!(code, "}}")?;
6684
6685    Ok(code)
6686}
6687
6688// ---------------------------------------------------------------------------
6689// Tests
6690// ---------------------------------------------------------------------------
6691
6692#[cfg(test)]
6693mod tests {
6694    use super::*;
6695    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
6696
6697    fn minimal_config() -> ModelConfig {
6698        ModelConfig {
6699            architecture: Architecture::Llama,
6700            hidden_size: 64,
6701            intermediate_size: 128,
6702            num_layers: 2,
6703            num_attention_heads: 4,
6704            num_kv_heads: 4,
6705            head_dim: 16,
6706            vocab_size: 256,
6707            max_seq_len: 512,
6708            rms_norm_eps: 1e-5,
6709            rope_theta: 10000.0,
6710            dtype: DType::F32,
6711            sliding_window_size: None,
6712            qkv_bias: false,
6713        }
6714    }
6715
6716    fn minimal_graph() -> Graph {
6717        Graph::new("test-metal").with_config(minimal_config())
6718    }
6719
6720    #[test]
6721    fn generate_metal_project_creates_files() {
6722        let dir = tempfile::tempdir().unwrap();
6723        let graph = minimal_graph();
6724        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
6725
6726        assert!(
6727            dir.path().join("Cargo.toml").exists(),
6728            "Cargo.toml should be created"
6729        );
6730        assert!(
6731            dir.path().join("src/model.rs").exists(),
6732            "src/model.rs should be created"
6733        );
6734        assert!(
6735            dir.path().join("src/main.rs").exists(),
6736            "src/main.rs should be created"
6737        );
6738        assert!(
6739            dir.path().join("shaders/kernels.metal").exists(),
6740            "shaders/kernels.metal should be created"
6741        );
6742    }
6743
6744    #[test]
6745    fn generated_cargo_toml_has_metal_dep() {
6746        let toml = generate_cargo_toml("my-model");
6747        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
6748        assert!(
6749            toml.contains("tokenizers"),
6750            "Cargo.toml should depend on tokenizers"
6751        );
6752        assert!(
6753            toml.contains("memmap2"),
6754            "Cargo.toml should depend on memmap2"
6755        );
6756        assert!(toml.contains("half"), "Cargo.toml should depend on half");
6757    }
6758
6759    #[test]
6760    fn generated_model_rs_contains_metal_code() {
6761        let config = minimal_config();
6762        let model_rs = generate_model_rs(&config).unwrap();
6763
6764        assert!(
6765            model_rs.contains("pub struct MetalModel"),
6766            "model.rs should define MetalModel struct"
6767        );
6768        assert!(
6769            model_rs.contains("matmul_pipeline: ComputePipelineState"),
6770            "MetalModel should have matmul_pipeline field"
6771        );
6772        assert!(
6773            model_rs.contains("Device::system_default()"),
6774            "model.rs should use Metal device"
6775        );
6776        assert!(
6777            model_rs.contains("new_library_with_source"),
6778            "model.rs should compile Metal shaders"
6779        );
6780        assert!(
6781            model_rs.contains("fn new(weights: &[u8])"),
6782            "MetalModel should implement new()"
6783        );
6784        assert!(
6785            model_rs.contains("fn forward(&mut self, token_id: u32)"),
6786            "MetalModel should implement forward()"
6787        );
6788    }
6789
6790    #[test]
6791    fn generated_shaders_contain_kernel_names() {
6792        let shaders = generate_metal_shaders(&minimal_config());
6793
6794        assert!(
6795            shaders.contains("kernel void matmul_vec"),
6796            "shaders should contain matmul_vec kernel"
6797        );
6798        assert!(
6799            shaders.contains("kernel void rms_norm"),
6800            "shaders should contain rms_norm kernel"
6801        );
6802        assert!(
6803            shaders.contains("kernel void rope"),
6804            "shaders should contain rope kernel"
6805        );
6806        assert!(
6807            shaders.contains("kernel void softmax"),
6808            "shaders should contain softmax kernel"
6809        );
6810        assert!(
6811            shaders.contains("kernel void silu_mul("),
6812            "shaders should contain silu_mul kernel"
6813        );
6814        assert!(
6815            shaders.contains("kernel void silu_mul_fused"),
6816            "shaders should contain silu_mul_fused kernel"
6817        );
6818        assert!(
6819            shaders.contains("kernel void elementwise_add"),
6820            "shaders should contain elementwise_add kernel"
6821        );
6822        assert!(
6823            shaders.contains("kernel void attention"),
6824            "shaders should contain attention kernel"
6825        );
6826        assert!(
6827            shaders.contains("kernel void add_inplace"),
6828            "shaders should contain add_inplace kernel"
6829        );
6830        assert!(
6831            shaders.contains("kernel void copy_buffer"),
6832            "shaders should contain copy_buffer kernel"
6833        );
6834        assert!(
6835            shaders.contains("kernel void copy_offset"),
6836            "shaders should contain copy_offset kernel"
6837        );
6838    }
6839
6840    #[test]
6841    fn generated_shaders_use_simdgroup_features() {
6842        let shaders = generate_metal_shaders(&minimal_config());
6843
6844        assert!(
6845            shaders.contains("threadgroup_barrier"),
6846            "shaders should use threadgroup barriers"
6847        );
6848        assert!(
6849            shaders.contains("threadgroup float"),
6850            "shaders should use threadgroup shared memory"
6851        );
6852        assert!(
6853            shaders.contains("thread_index_in_threadgroup"),
6854            "shaders should use threadgroup indexing"
6855        );
6856        assert!(
6857            shaders.contains("simd_sum"),
6858            "shaders should use simd_sum for warp-level reduction"
6859        );
6860        assert!(
6861            shaders.contains("simd_max"),
6862            "attention kernel should use simd_max for cooperative softmax"
6863        );
6864        assert!(
6865            shaders.contains("thread_index_in_simdgroup"),
6866            "shaders should use simdgroup lane indexing"
6867        );
6868        assert!(
6869            shaders.contains("simdgroup_index_in_threadgroup"),
6870            "shaders should use simdgroup indexing within threadgroup"
6871        );
6872        assert!(
6873            shaders.contains("float4"),
6874            "matmul_vec should use float4 vectorized loads"
6875        );
6876    }
6877
6878    #[test]
6879    fn generated_main_rs_has_tokenizer_usage() {
6880        let config = minimal_config();
6881        let main_rs = generate_main_rs("test-model", &config).unwrap();
6882
6883        assert!(
6884            main_rs.contains("tokenizers::Tokenizer"),
6885            "main.rs should use tokenizers crate"
6886        );
6887        assert!(
6888            main_rs.contains("MetalModel::new"),
6889            "main.rs should call MetalModel::new"
6890        );
6891        assert!(
6892            main_rs.contains("model.forward"),
6893            "main.rs should call model.forward"
6894        );
6895        assert!(
6896            main_rs.contains("memmap2"),
6897            "main.rs should use memmap2 for zero-copy weight loading"
6898        );
6899    }
6900
6901    #[test]
6902    fn missing_config_returns_error() {
6903        let dir = tempfile::tempdir().unwrap();
6904        let graph = Graph::new("no-config");
6905        let result = generate_metal_project(&graph, dir.path(), "fail");
6906        assert!(
6907            matches!(result, Err(MetalCodegenError::MissingConfig)),
6908            "should fail with MissingConfig when graph has no config"
6909        );
6910    }
6911
6912    #[test]
6913    fn sanitize_name_works() {
6914        assert_eq!(sanitize_name("My Model!"), "my-model");
6915        assert_eq!(sanitize_name("test_model"), "test-model");
6916        assert_eq!(sanitize_name("simple"), "simple");
6917    }
6918
6919    #[test]
6920    fn generated_forward_uses_single_command_buffer() {
6921        let config = minimal_config();
6922        let model_rs = generate_model_rs(&config).unwrap();
6923
6924        // The forward function should create exactly one command buffer.
6925        // Use the exact signature to avoid matching forward_prefill/forward_profile.
6926        let forward_start = model_rs
6927            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6928            .unwrap();
6929        let forward_body = &model_rs[forward_start..];
6930        // End at the next pub/private method
6931        let forward_end = forward_body
6932            .find("\n    pub fn forward_profile")
6933            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6934            .or_else(|| forward_body.find("\n    fn dispatch_"))
6935            .unwrap_or(forward_body.len());
6936        let forward_code = &forward_body[..forward_end];
6937
6938        // Should have exactly one new_command_buffer call
6939        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
6940        assert_eq!(
6941            cmd_buf_count, 1,
6942            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
6943        );
6944
6945        // Should have exactly one commit call
6946        let commit_count = forward_code.matches("cmd.commit()").count();
6947        assert_eq!(
6948            commit_count, 1,
6949            "forward() should commit exactly once, found {commit_count}"
6950        );
6951
6952        // Should wait: once for cmd + possibly once for prev_cmd drain
6953        let wait_count = forward_code.matches("wait_until_completed()").count();
6954        assert!(
6955            wait_count >= 1 && wait_count <= 2,
6956            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
6957        );
6958    }
6959
6960    #[test]
6961    fn generated_model_has_preallocated_working_buffers() {
6962        let config = minimal_config();
6963        let model_rs = generate_model_rs(&config).unwrap();
6964
6965        for buf_name in &[
6966            "normed_buf",
6967            "qkv_buf",
6968            "attn_out_buf",
6969            "attn_proj_buf",
6970            "gate_up_buf",
6971            "ffn_hidden_buf",
6972            "ffn_out_buf",
6973            "add_tmp_buf",
6974        ] {
6975            assert!(
6976                model_rs.contains(&format!("{buf_name}: Buffer")),
6977                "MetalModel should have pre-allocated {buf_name} field"
6978            );
6979        }
6980    }
6981
6982    #[test]
6983    fn generated_dispatch_helpers_take_compute_encoder_ref() {
6984        let config = minimal_config();
6985        let model_rs = generate_model_rs(&config).unwrap();
6986
6987        for method in &[
6988            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
6989            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
6990            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
6991            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
6992            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
6993            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
6994            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
6995            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
6996            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
6997            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
6998            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
6999            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7000            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7001        ] {
7002            assert!(
7003                model_rs.contains(method),
7004                "model.rs should contain dispatch helper: {method}"
7005            );
7006        }
7007    }
7008
7009    #[test]
7010    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7011        let config = minimal_config();
7012        let model_rs = generate_model_rs(&config).unwrap();
7013
7014        // Find dispatch helpers section and check none create their own encoders
7015        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7016        let helpers_code = &model_rs[helpers_start..];
7017
7018        // None of the dispatch_ helpers should call new_command_buffer
7019        assert!(
7020            !helpers_code.contains("self.queue.new_command_buffer()"),
7021            "dispatch helpers should not create their own command buffers"
7022        );
7023
7024        // None should create their own compute encoders
7025        assert!(
7026            !helpers_code.contains("new_compute_command_encoder()"),
7027            "dispatch helpers should not create their own compute encoders"
7028        );
7029
7030        // None should call end_encoding
7031        assert!(
7032            !helpers_code.contains("end_encoding()"),
7033            "dispatch helpers should not call end_encoding"
7034        );
7035
7036        // None should call commit or wait
7037        assert!(
7038            !helpers_code.contains(".commit()"),
7039            "dispatch helpers should not commit command buffers"
7040        );
7041        assert!(
7042            !helpers_code.contains("wait_until_completed"),
7043            "dispatch helpers should not wait on command buffers"
7044        );
7045    }
7046
7047    #[test]
7048    fn generated_forward_batches_compute_encoders() {
7049        let config = minimal_config();
7050        let model_rs = generate_model_rs(&config).unwrap();
7051
7052        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7053        let forward_start = model_rs
7054            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7055            .unwrap();
7056        let forward_body = &model_rs[forward_start..];
7057        let forward_end = forward_body
7058            .find("\n    pub fn forward_profile")
7059            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7060            .or_else(|| forward_body.find("\n    fn dispatch_"))
7061            .unwrap_or(forward_body.len());
7062        let forward_code = &forward_body[..forward_end];
7063
7064        // Forward should not allocate new buffers
7065        assert!(
7066            !forward_code.contains("device.new_buffer"),
7067            "forward() should not allocate new buffers per call"
7068        );
7069
7070        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7071        // Copy operations use compute copy kernels instead of blit encoders.
7072        let compute_encoder_count = forward_code
7073            .matches("new_compute_command_encoder()")
7074            .count();
7075        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7076
7077        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7078        assert_eq!(
7079            compute_encoder_count, 1,
7080            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7081        );
7082        assert_eq!(
7083            blit_encoder_count, 0,
7084            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7085        );
7086    }
7087
7088    #[test]
7089    fn generated_forward_uses_add_inplace() {
7090        let config = minimal_config();
7091        let model_rs = generate_model_rs(&config).unwrap();
7092
7093        // Should use in-place add (no blit copy-back needed)
7094        assert!(
7095            model_rs.contains("dispatch_add_inplace"),
7096            "forward() should use dispatch_add_inplace for residual connections"
7097        );
7098        assert!(
7099            model_rs.contains("add_inplace_pipeline"),
7100            "MetalModel should have add_inplace_pipeline"
7101        );
7102    }
7103
7104    fn minimal_q8_config() -> ModelConfig {
7105        ModelConfig {
7106            architecture: Architecture::Llama,
7107            hidden_size: 64,
7108            intermediate_size: 128,
7109            num_layers: 2,
7110            num_attention_heads: 4,
7111            num_kv_heads: 4,
7112            head_dim: 16,
7113            vocab_size: 256,
7114            max_seq_len: 512,
7115            rms_norm_eps: 1e-5,
7116            rope_theta: 10000.0,
7117            dtype: DType::Q8_0,
7118            sliding_window_size: None,
7119            qkv_bias: false,
7120        }
7121    }
7122
7123    #[test]
7124    fn generated_shaders_contain_q8_kernel() {
7125        let shaders = generate_metal_shaders(&minimal_config());
7126
7127        assert!(
7128            shaders.contains("kernel void matmul_vec_q8"),
7129            "shaders should contain matmul_vec_q8 kernel"
7130        );
7131        assert!(
7132            shaders.contains("device const uchar* matrix"),
7133            "matmul_vec_q8 should accept raw Q8_0 bytes"
7134        );
7135        assert!(
7136            shaders.contains("packed_short4"),
7137            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7138        );
7139        assert!(
7140            shaders.contains("as_type<char2>"),
7141            "matmul_vec_q8 should bitcast short lanes to char2"
7142        );
7143        assert!(
7144            shaders.contains("device const half*"),
7145            "matmul_vec_q8 should read f16 scale via half pointer"
7146        );
7147    }
7148
7149    #[test]
7150    fn generated_model_uses_fused_qkv_projections() {
7151        let config = minimal_config();
7152        let model_rs = generate_model_rs(&config).unwrap();
7153
7154        // Should have fused QKV weight in layer buffers
7155        assert!(
7156            model_rs.contains("qkv_weight: Buffer"),
7157            "LayerBuffers should have fused qkv_weight field"
7158        );
7159        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7160        assert!(
7161            !model_rs.contains("    q_weight: Buffer"),
7162            "LayerBuffers should not have separate q_weight field"
7163        );
7164        assert!(
7165            !model_rs.contains("    k_weight: Buffer"),
7166            "LayerBuffers should not have separate k_weight field"
7167        );
7168        assert!(
7169            !model_rs.contains("    v_weight: Buffer"),
7170            "LayerBuffers should not have separate v_weight field"
7171        );
7172
7173        // Should have fused gate_up_weight
7174        assert!(
7175            model_rs.contains("gate_up_weight: Buffer"),
7176            "LayerBuffers should have fused gate_up_weight field"
7177        );
7178        // Should NOT have separate gate/up weight fields
7179        assert!(
7180            !model_rs.contains("    gate_weight: Buffer"),
7181            "LayerBuffers should not have separate gate_weight field"
7182        );
7183        assert!(
7184            !model_rs.contains("    up_weight: Buffer"),
7185            "LayerBuffers should not have separate up_weight field"
7186        );
7187
7188        // Should have fused working buffers
7189        assert!(
7190            model_rs.contains("qkv_buf: Buffer"),
7191            "MetalModel should have fused qkv_buf"
7192        );
7193        assert!(
7194            model_rs.contains("gate_up_buf: Buffer"),
7195            "MetalModel should have fused gate_up_buf"
7196        );
7197
7198        // Forward pass should use fused dispatch
7199        assert!(
7200            model_rs.contains("dispatch_silu_mul_fused"),
7201            "forward pass should use dispatch_silu_mul_fused"
7202        );
7203        assert!(
7204            model_rs.contains("dispatch_rope_offset"),
7205            "forward pass should use dispatch_rope_offset for fused QKV"
7206        );
7207        assert!(
7208            model_rs.contains("dispatch_attention_offset"),
7209            "forward pass should use dispatch_attention_offset for fused QKV"
7210        );
7211    }
7212
7213    #[test]
7214    fn q8_model_has_matmul_q8_pipeline() {
7215        let config = minimal_q8_config();
7216        let model_rs = generate_model_rs(&config).unwrap();
7217
7218        assert!(
7219            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7220            "MetalModel should have matmul_q8_pipeline field"
7221        );
7222        assert!(
7223            model_rs.contains("matmul_q8_pipeline,"),
7224            "MetalModel Self should include matmul_q8_pipeline"
7225        );
7226    }
7227
7228    #[test]
7229    fn q8_model_uses_dispatch_matmul_q8() {
7230        let config = minimal_q8_config();
7231        let model_rs = generate_model_rs(&config).unwrap();
7232
7233        assert!(
7234            model_rs.contains("dispatch_matmul_q8"),
7235            "Q8_0 model should use dispatch_matmul_q8 for projections"
7236        );
7237        assert!(
7238            model_rs.contains("fn dispatch_matmul_q8"),
7239            "model.rs should define dispatch_matmul_q8 method"
7240        );
7241    }
7242
7243    #[test]
7244    fn q8_model_loads_raw_bytes_not_dequantized() {
7245        let config = minimal_q8_config();
7246        let model_rs = generate_model_rs(&config).unwrap();
7247
7248        // Should NOT contain dequantization code
7249        assert!(
7250            !model_rs.contains("f16_to_f32"),
7251            "Q8_0 model should not dequantize weights to f32"
7252        );
7253        assert!(
7254            !model_rs.contains("f32_data"),
7255            "Q8_0 model should not create f32 weight data"
7256        );
7257
7258        // Should load raw Q8_0 bytes directly
7259        assert!(
7260            model_rs.contains("total_raw as u64"),
7261            "Q8_0 model should load raw bytes into Metal buffer"
7262        );
7263    }
7264
7265    #[test]
7266    fn q8_model_norms_stay_f32() {
7267        let config = minimal_q8_config();
7268        let model_rs = generate_model_rs(&config).unwrap();
7269
7270        // Norm weights should still use f32 buffers
7271        assert!(
7272            model_rs.contains("let attn_norm = next_f32_buffer"),
7273            "attn_norm should use f32 buffer even for Q8_0 models"
7274        );
7275        assert!(
7276            model_rs.contains("let ffn_norm = next_f32_buffer"),
7277            "ffn_norm should use f32 buffer even for Q8_0 models"
7278        );
7279        assert!(
7280            model_rs.contains("let norm_buf = next_f32_buffer"),
7281            "final norm should use f32 buffer even for Q8_0 models"
7282        );
7283    }
7284
7285    #[test]
7286    fn q8_model_uses_fused_weight_loading() {
7287        let config = minimal_q8_config();
7288        let model_rs = generate_model_rs(&config).unwrap();
7289
7290        // Should use fused Q8 buffer loading for QKV
7291        assert!(
7292            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7293            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7294        );
7295        // Should use fused Q8 buffer loading for gate+up
7296        assert!(
7297            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7298            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7299        );
7300        // Should still use regular q8 buffer for individual weights
7301        assert!(
7302            model_rs.contains("let o_weight = next_q8_buffer"),
7303            "Q8_0 model should use next_q8_buffer for O weight"
7304        );
7305        assert!(
7306            model_rs.contains("let down_weight = next_q8_buffer"),
7307            "Q8_0 model should use next_q8_buffer for down weight"
7308        );
7309    }
7310
7311    #[test]
7312    fn f32_model_does_not_use_q8_dispatch() {
7313        let config = minimal_config();
7314        let model_rs = generate_model_rs(&config).unwrap();
7315
7316        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7317        let forward_start = model_rs
7318            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7319            .unwrap();
7320        let forward_body = &model_rs[forward_start..];
7321        let forward_end = forward_body
7322            .find("\n    fn dispatch_")
7323            .unwrap_or(forward_body.len());
7324        let forward_code = &forward_body[..forward_end];
7325
7326        assert!(
7327            !forward_code.contains("dispatch_matmul_q8"),
7328            "f32 model forward should not use dispatch_matmul_q8"
7329        );
7330    }
7331
7332    #[test]
7333    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7334        let config = minimal_q8_config();
7335        let model_rs = generate_model_rs(&config).unwrap();
7336
7337        assert!(
7338            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7339            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7340        );
7341    }
7342
7343    #[test]
7344    fn generated_model_has_double_buffered_prefill() {
7345        let config = minimal_config();
7346        let model_rs = generate_model_rs(&config).unwrap();
7347
7348        // MetalModel should have prev_cmd field for double-buffered prefill
7349        assert!(
7350            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7351            "MetalModel should have prev_cmd field for double-buffered prefill"
7352        );
7353
7354        // Should have forward_prefill method
7355        assert!(
7356            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7357            "MetalModel should have forward_prefill method"
7358        );
7359
7360        // forward() should drain prev_cmd at the start
7361        assert!(
7362            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7363            "forward() should drain prev_cmd from previous prefill"
7364        );
7365    }
7366
7367    #[test]
7368    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7369        let config = minimal_config();
7370        let main_rs = generate_main_rs("test-model", &config).unwrap();
7371
7372        assert!(
7373            main_rs.contains("forward_prefill"),
7374            "main.rs should use forward_prefill for intermediate prompt tokens"
7375        );
7376        assert!(
7377            main_rs.contains("double-buffered"),
7378            "main.rs should document double-buffered prefill"
7379        );
7380    }
7381
7382    #[test]
7383    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7384        let shaders = generate_metal_shaders(&minimal_config());
7385
7386        assert!(
7387            shaders.contains("packed_short4"),
7388            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7389        );
7390        assert!(
7391            shaders.contains("d0[0]"),
7392            "matmul_vec_q8 should index the wide pointer for row 0"
7393        );
7394        assert!(
7395            shaders.contains("as_type<char2>"),
7396            "matmul_vec_q8 should bitcast short lanes to char2"
7397        );
7398        assert!(
7399            shaders.contains("dot("),
7400            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7401        );
7402    }
7403
7404    // ── Q4_0 tests ──────────────────────────────────────────────────────
7405
7406    fn minimal_q4_config() -> ModelConfig {
7407        ModelConfig {
7408            architecture: Architecture::Llama,
7409            hidden_size: 64,
7410            intermediate_size: 128,
7411            num_layers: 2,
7412            num_attention_heads: 4,
7413            num_kv_heads: 4,
7414            head_dim: 16,
7415            vocab_size: 256,
7416            max_seq_len: 512,
7417            rms_norm_eps: 1e-5,
7418            rope_theta: 10000.0,
7419            dtype: DType::Q4_0,
7420            sliding_window_size: None,
7421            qkv_bias: false,
7422        }
7423    }
7424
7425    #[test]
7426    fn generated_shaders_contain_q4_kernel() {
7427        let shaders = generate_metal_shaders(&minimal_config());
7428
7429        assert!(
7430            shaders.contains("kernel void matmul_vec_q4"),
7431            "shaders should contain matmul_vec_q4 kernel"
7432        );
7433        assert!(
7434            shaders.contains("Q4_ROWS_PER_TG"),
7435            "shaders should define Q4_ROWS_PER_TG constant"
7436        );
7437        assert!(
7438            shaders.contains("Q4_ROWS_PER_SG"),
7439            "shaders should define Q4_ROWS_PER_SG constant"
7440        );
7441    }
7442
7443    #[test]
7444    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7445        let shaders = generate_metal_shaders(&minimal_config());
7446
7447        // Q4_0 kernel should use uchar4 for packed byte loads
7448        assert!(
7449            shaders.contains("uchar4"),
7450            "matmul_vec_q4 should use uchar4 for packed byte loads"
7451        );
7452        // Should unpack nibbles with &0xF and >>4
7453        assert!(
7454            shaders.contains("&0xF"),
7455            "matmul_vec_q4 should extract low nibble with &0xF"
7456        );
7457        assert!(
7458            shaders.contains(">>4"),
7459            "matmul_vec_q4 should extract high nibble with >>4"
7460        );
7461        // Should subtract 8 to convert unsigned to signed
7462        assert!(
7463            shaders.contains("-8)"),
7464            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7465        );
7466        // Should use 18-byte block size
7467        assert!(
7468            shaders.contains("blk * 18"),
7469            "matmul_vec_q4 should use 18-byte block stride"
7470        );
7471    }
7472
7473    #[test]
7474    fn q4_model_has_matmul_q4_pipeline() {
7475        let config = minimal_q4_config();
7476        let model_rs = generate_model_rs(&config).unwrap();
7477
7478        assert!(
7479            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7480            "MetalModel should have matmul_q4_pipeline field"
7481        );
7482        assert!(
7483            model_rs.contains("matmul_q4_pipeline,"),
7484            "MetalModel Self should include matmul_q4_pipeline"
7485        );
7486    }
7487
7488    #[test]
7489    fn q4_model_uses_dispatch_matmul_q4() {
7490        let config = minimal_q4_config();
7491        let model_rs = generate_model_rs(&config).unwrap();
7492
7493        assert!(
7494            model_rs.contains("dispatch_matmul_q4"),
7495            "Q4_0 model should use dispatch_matmul_q4 for projections"
7496        );
7497        assert!(
7498            model_rs.contains("fn dispatch_matmul_q4"),
7499            "model.rs should define dispatch_matmul_q4 method"
7500        );
7501    }
7502
7503    #[test]
7504    fn q4_model_loads_raw_bytes_not_dequantized() {
7505        let config = minimal_q4_config();
7506        let model_rs = generate_model_rs(&config).unwrap();
7507
7508        // Should NOT contain dequantization code
7509        assert!(
7510            !model_rs.contains("f16_to_f32"),
7511            "Q4_0 model should not dequantize weights to f32"
7512        );
7513
7514        // Should load raw Q4_0 bytes directly
7515        assert!(
7516            model_rs.contains("total_raw as u64"),
7517            "Q4_0 model should load raw bytes into Metal buffer"
7518        );
7519    }
7520
7521    #[test]
7522    fn q4_model_norms_stay_f32() {
7523        let config = minimal_q4_config();
7524        let model_rs = generate_model_rs(&config).unwrap();
7525
7526        assert!(
7527            model_rs.contains("let attn_norm = next_f32_buffer"),
7528            "attn_norm should use f32 buffer even for Q4_0 models"
7529        );
7530        assert!(
7531            model_rs.contains("let ffn_norm = next_f32_buffer"),
7532            "ffn_norm should use f32 buffer even for Q4_0 models"
7533        );
7534        assert!(
7535            model_rs.contains("let norm_buf = next_f32_buffer"),
7536            "final norm should use f32 buffer even for Q4_0 models"
7537        );
7538    }
7539
7540    #[test]
7541    fn q4_model_uses_fused_weight_loading() {
7542        let config = minimal_q4_config();
7543        let model_rs = generate_model_rs(&config).unwrap();
7544
7545        assert!(
7546            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7547            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7548        );
7549        assert!(
7550            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7551            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7552        );
7553        assert!(
7554            model_rs.contains("let o_weight = next_q4_buffer"),
7555            "Q4_0 model should use next_q4_buffer for O weight"
7556        );
7557        assert!(
7558            model_rs.contains("let down_weight = next_q4_buffer"),
7559            "Q4_0 model should use next_q4_buffer for down weight"
7560        );
7561    }
7562
7563    #[test]
7564    fn attention_flash_batch_kernel_wired() {
7565        // The legacy attention_batch kernel has a hardcoded scores[2048] cap
7566        // which silently overflows for base_pos+num_tokens > 2048.  Flash
7567        // attention kernel + tiered dispatch were added in v0.6.x to fix this.
7568        let config = minimal_config();
7569        let model_rs = generate_model_rs(&config).unwrap();
7570        let shaders = generate_metal_shaders(&config);
7571
7572        assert!(
7573            shaders.contains("kernel void attention_flash_batch"),
7574            "shaders.metal must contain the attention_flash_batch kernel"
7575        );
7576        assert!(
7577            shaders.contains("FLASH_K_TILE"),
7578            "flash kernel must tile K/V with a FLASH_K_TILE constant"
7579        );
7580        assert!(
7581            model_rs.contains("attention_flash_batch_pipeline"),
7582            "MetalModel must register the flash attention pipeline"
7583        );
7584        // Tiered dispatch — legacy below 2000, flash above.
7585        assert!(
7586            model_rs.contains("max_seq > 2000"),
7587            "dispatch_attention_batch must branch on max_seq > 2000"
7588        );
7589        assert!(
7590            model_rs.contains("&self.attention_flash_batch_pipeline"),
7591            "dispatch must reference the flash pipeline"
7592        );
7593    }
7594
7595    #[test]
7596    fn qwen2_qkv_bias_wired_through_metal_codegen() {
7597        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
7598        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
7599        // loader, pipeline, and dispatch call in the expected places.
7600        let config = ModelConfig {
7601            architecture: Architecture::Qwen2,
7602            qkv_bias: true,
7603            ..minimal_config()
7604        };
7605        let model_rs = generate_model_rs(&config).unwrap();
7606
7607        assert!(
7608            model_rs.contains("qkv_bias: Buffer"),
7609            "Qwen2 LayerBuffers must declare qkv_bias field"
7610        );
7611        assert!(
7612            model_rs.contains("let qkv_bias = next_f32_buffer"),
7613            "Qwen2 layer init must load the bias from the weight blob"
7614        );
7615        assert!(
7616            model_rs.contains("add_bias_batch_pipeline"),
7617            "Qwen2 model struct must include the add_bias_batch_pipeline"
7618        );
7619        assert!(
7620            model_rs.contains("fn dispatch_add_bias_batch"),
7621            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
7622        );
7623        assert!(
7624            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
7625            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
7626        );
7627        assert!(
7628            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
7629            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
7630        );
7631
7632        // The add_bias_batch MSL kernel must be in the shader source.
7633        let shaders = generate_metal_shaders(&config);
7634        assert!(
7635            shaders.contains("kernel void add_bias_batch"),
7636            "shaders.metal must contain the add_bias_batch kernel"
7637        );
7638    }
7639
7640    #[test]
7641    fn llama_does_not_emit_qkv_bias_machinery() {
7642        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
7643        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
7644        let config = minimal_config();
7645        assert!(!config.qkv_bias);
7646        let model_rs = generate_model_rs(&config).unwrap();
7647        assert!(
7648            !model_rs.contains("qkv_bias: Buffer"),
7649            "Llama must not have qkv_bias field"
7650        );
7651        assert!(
7652            !model_rs.contains("add_bias_batch_pipeline"),
7653            "Llama must not pull in add_bias_batch_pipeline"
7654        );
7655        assert!(
7656            !model_rs.contains("dispatch_add_bias_batch"),
7657            "Llama must not call dispatch_add_bias_batch"
7658        );
7659    }
7660
7661    #[test]
7662    fn q4_dispatch_helper_takes_compute_encoder_ref() {
7663        let config = minimal_q4_config();
7664        let model_rs = generate_model_rs(&config).unwrap();
7665
7666        assert!(
7667            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
7668            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
7669        );
7670    }
7671
7672    #[test]
7673    fn f32_model_does_not_use_q4_dispatch() {
7674        let config = minimal_config();
7675        let model_rs = generate_model_rs(&config).unwrap();
7676
7677        let forward_start = model_rs
7678            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7679            .unwrap();
7680        let forward_body = &model_rs[forward_start..];
7681        let forward_end = forward_body
7682            .find("\n    fn dispatch_")
7683            .unwrap_or(forward_body.len());
7684        let forward_code = &forward_body[..forward_end];
7685
7686        assert!(
7687            !forward_code.contains("dispatch_matmul_q4"),
7688            "f32 model forward should not use dispatch_matmul_q4"
7689        );
7690    }
7691
7692    #[test]
7693    fn q4_model_lm_head_uses_q4_buffer() {
7694        let config = minimal_q4_config();
7695        let model_rs = generate_model_rs(&config).unwrap();
7696
7697        assert!(
7698            model_rs.contains("let lm_head_buf = next_q4_buffer"),
7699            "Q4_0 model should use next_q4_buffer for lm_head"
7700        );
7701    }
7702
7703    #[test]
7704    fn vec_tile_size_matches_model_dimensions() {
7705        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
7706        let small = minimal_config();
7707        let shaders_small = generate_metal_shaders(&small);
7708        assert!(
7709            shaders_small.contains("vec_tile[128]"),
7710            "vec_tile should be sized to max(hidden, intermediate) = 128"
7711        );
7712
7713        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
7714        let mut large = minimal_config();
7715        large.hidden_size = 2048;
7716        large.intermediate_size = 8192;
7717        let shaders_large = generate_metal_shaders(&large);
7718        assert!(
7719            shaders_large.contains("vec_tile[8192]"),
7720            "vec_tile should be 8192 for models with intermediate=8192"
7721        );
7722        assert!(
7723            !shaders_large.contains("vec_tile[4096]"),
7724            "vec_tile should NOT be hardcoded to 4096"
7725        );
7726    }
7727
7728    #[test]
7729    fn generated_cargo_toml_has_server_deps() {
7730        let toml = generate_cargo_toml("my-model");
7731        assert!(
7732            toml.contains("tiny_http"),
7733            "Cargo.toml should depend on tiny_http"
7734        );
7735        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
7736        assert!(
7737            toml.contains("serde_json"),
7738            "Cargo.toml should depend on serde_json"
7739        );
7740    }
7741
7742    #[test]
7743    fn generated_main_rs_has_serve_mode() {
7744        let config = minimal_config();
7745        let main_rs = generate_main_rs("test-model", &config).unwrap();
7746
7747        assert!(
7748            main_rs.contains("--serve"),
7749            "main.rs should parse --serve flag"
7750        );
7751        assert!(
7752            main_rs.contains("--port"),
7753            "main.rs should parse --port flag"
7754        );
7755        assert!(
7756            main_rs.contains("fn serve("),
7757            "main.rs should define serve function"
7758        );
7759        assert!(
7760            main_rs.contains("tiny_http::Server::http"),
7761            "main.rs should create tiny_http server"
7762        );
7763    }
7764
7765    #[test]
7766    fn generated_main_rs_has_chat_completions_endpoint() {
7767        let config = minimal_config();
7768        let main_rs = generate_main_rs("test-model", &config).unwrap();
7769
7770        assert!(
7771            main_rs.contains("/v1/chat/completions"),
7772            "main.rs should handle /v1/chat/completions endpoint"
7773        );
7774        assert!(
7775            main_rs.contains("/v1/models"),
7776            "main.rs should handle /v1/models endpoint"
7777        );
7778        assert!(
7779            main_rs.contains("/health"),
7780            "main.rs should handle /health endpoint"
7781        );
7782    }
7783
7784    #[test]
7785    fn generated_main_rs_has_sse_streaming() {
7786        let config = minimal_config();
7787        let main_rs = generate_main_rs("test-model", &config).unwrap();
7788
7789        assert!(
7790            main_rs.contains("text/event-stream"),
7791            "main.rs should set SSE content type for streaming"
7792        );
7793        assert!(
7794            main_rs.contains("chat.completion.chunk"),
7795            "main.rs should emit SSE chunks"
7796        );
7797        assert!(
7798            main_rs.contains("[DONE]"),
7799            "main.rs should emit [DONE] sentinel"
7800        );
7801    }
7802
7803    #[test]
7804    fn generated_main_rs_has_chat_message_formatting() {
7805        let config = minimal_config();
7806        let main_rs = generate_main_rs("test-model", &config).unwrap();
7807
7808        assert!(
7809            main_rs.contains("fn format_chat_messages"),
7810            "main.rs should define format_chat_messages function"
7811        );
7812        assert!(
7813            main_rs.contains("<|im_start|>"),
7814            "main.rs should use ChatML format"
7815        );
7816        assert!(
7817            main_rs.contains("<|im_end|>"),
7818            "main.rs should use ChatML format"
7819        );
7820    }
7821
7822    #[test]
7823    fn generated_main_rs_has_request_types() {
7824        let config = minimal_config();
7825        let main_rs = generate_main_rs("test-model", &config).unwrap();
7826
7827        assert!(
7828            main_rs.contains("struct ChatRequest"),
7829            "main.rs should define ChatRequest struct"
7830        );
7831        assert!(
7832            main_rs.contains("struct ChatMessage"),
7833            "main.rs should define ChatMessage struct"
7834        );
7835        assert!(
7836            main_rs.contains("Deserialize"),
7837            "main.rs should derive Deserialize for request types"
7838        );
7839    }
7840
7841    #[test]
7842    fn generated_model_has_reset_method() {
7843        let config = minimal_config();
7844        let model_rs = generate_model_rs(&config).unwrap();
7845
7846        assert!(
7847            model_rs.contains("pub fn reset(&mut self)"),
7848            "model.rs should have a reset() method for multi-request serving"
7849        );
7850        assert!(
7851            model_rs.contains("self.pos = 0"),
7852            "reset() should reset position to 0"
7853        );
7854    }
7855
7856    #[test]
7857    fn generated_main_rs_cli_mode_still_works() {
7858        let config = minimal_config();
7859        let main_rs = generate_main_rs("test-model", &config).unwrap();
7860
7861        // CLI mode should still be functional
7862        assert!(
7863            main_rs.contains("fn cli_mode("),
7864            "main.rs should define cli_mode function"
7865        );
7866        assert!(
7867            main_rs.contains("model.forward"),
7868            "main.rs should call model.forward"
7869        );
7870        assert!(
7871            main_rs.contains("model.forward_prefill"),
7872            "main.rs should call model.forward_prefill"
7873        );
7874    }
7875
7876    // ── Batched prefill tests ──────────────────────────────────────────
7877
7878    #[test]
7879    fn generated_shaders_contain_batch_kernels() {
7880        let shaders = generate_metal_shaders(&minimal_config());
7881
7882        assert!(
7883            shaders.contains("kernel void matmul_vec_batch"),
7884            "shaders should contain matmul_vec_batch kernel"
7885        );
7886        assert!(
7887            shaders.contains("kernel void matmul_vec_q8_batch"),
7888            "shaders should contain matmul_vec_q8_batch kernel"
7889        );
7890        assert!(
7891            shaders.contains("kernel void matmul_q8_gemm_batch"),
7892            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
7893        );
7894        assert!(
7895            shaders.contains("kernel void matmul_vec_q4_batch"),
7896            "shaders should contain matmul_vec_q4_batch kernel"
7897        );
7898        assert!(
7899            shaders.contains("kernel void rms_norm_batch"),
7900            "shaders should contain rms_norm_batch kernel"
7901        );
7902        assert!(
7903            shaders.contains("kernel void silu_mul_fused_batch"),
7904            "shaders should contain silu_mul_fused_batch kernel"
7905        );
7906        assert!(
7907            shaders.contains("kernel void add_inplace_batch"),
7908            "shaders should contain add_inplace_batch kernel"
7909        );
7910        assert!(
7911            shaders.contains("kernel void copy_embedding_batch"),
7912            "shaders should contain copy_embedding_batch kernel"
7913        );
7914    }
7915
7916    #[test]
7917    fn generated_model_has_batch_pipelines() {
7918        let config = minimal_config();
7919        let model_rs = generate_model_rs(&config).unwrap();
7920
7921        for pipeline in &[
7922            "matmul_batch_pipeline",
7923            "matmul_q8_batch_pipeline",
7924            "matmul_q8_gemm_batch_pipeline",
7925            "matmul_q4_batch_pipeline",
7926            "rms_norm_batch_pipeline",
7927            "rope_batch_pipeline",
7928            "silu_mul_fused_batch_pipeline",
7929            "add_inplace_batch_pipeline",
7930            "copy_embedding_batch_pipeline",
7931            "attention_batch_pipeline",
7932            "copy_kv_batch_pipeline",
7933        ] {
7934            assert!(
7935                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
7936                "MetalModel should have {pipeline} field"
7937            );
7938        }
7939    }
7940
7941    #[test]
7942    fn generated_model_has_batch_buffers() {
7943        let config = minimal_config();
7944        let model_rs = generate_model_rs(&config).unwrap();
7945
7946        for buf in &[
7947            "batch_hidden_buf",
7948            "batch_residual_buf",
7949            "batch_qkv_buf",
7950            "batch_attn_out_buf",
7951            "batch_attn_proj_buf",
7952            "batch_gate_up_buf",
7953            "batch_ffn_hidden_buf",
7954            "batch_ffn_out_buf",
7955            "batch_tokens_buf",
7956            "batch_positions_buf",
7957        ] {
7958            assert!(
7959                model_rs.contains(&format!("{buf}: Buffer")),
7960                "MetalModel should have {buf} field"
7961            );
7962        }
7963    }
7964
7965    #[test]
7966    fn generated_model_has_forward_prefill_batch() {
7967        let config = minimal_config();
7968        let model_rs = generate_model_rs(&config).unwrap();
7969
7970        assert!(
7971            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
7972            "MetalModel should have forward_prefill_batch method"
7973        );
7974
7975        // forward_prefill should delegate to forward_prefill_batch
7976        assert!(
7977            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
7978            "forward_prefill should delegate to forward_prefill_batch"
7979        );
7980    }
7981
7982    #[test]
7983    fn generated_model_has_max_batch_size_constant() {
7984        let config = minimal_config();
7985        let model_rs = generate_model_rs(&config).unwrap();
7986
7987        assert!(
7988            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
7989            "model.rs should define MAX_BATCH_SIZE constant"
7990        );
7991    }
7992
7993    #[test]
7994    fn forward_prefill_batch_uses_batch_dispatch() {
7995        let config = minimal_config();
7996        let model_rs = generate_model_rs(&config).unwrap();
7997
7998        let batch_start = model_rs
7999            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8000            .unwrap();
8001        let batch_body = &model_rs[batch_start..];
8002        let batch_end = batch_body
8003            .find("\n    pub fn reset")
8004            .unwrap_or(batch_body.len());
8005        let batch_code = &batch_body[..batch_end];
8006
8007        // Should use batched dispatch methods
8008        assert!(
8009            batch_code.contains("dispatch_rms_norm_batch"),
8010            "forward_prefill_batch should use dispatch_rms_norm_batch"
8011        );
8012        assert!(
8013            batch_code.contains("dispatch_copy_embedding_batch"),
8014            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8015        );
8016        assert!(
8017            batch_code.contains("dispatch_silu_mul_fused_batch"),
8018            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8019        );
8020        // Should use batched causal attention dispatch
8021        assert!(
8022            batch_code.contains("dispatch_attention_batch"),
8023            "forward_prefill_batch should use dispatch_attention_batch"
8024        );
8025        // Should use fused KV cache copy (both K and V in one dispatch)
8026        assert!(
8027            batch_code.contains("dispatch_copy_kv_both_batch"),
8028            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8029        );
8030        // Should use fused RoPE Q+K dispatch
8031        assert!(
8032            batch_code.contains("dispatch_rope_qk_batch"),
8033            "forward_prefill_batch should use dispatch_rope_qk_batch"
8034        );
8035    }
8036
8037    #[test]
8038    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8039        let config = minimal_q8_config();
8040        let model_rs = generate_model_rs(&config).unwrap();
8041
8042        let batch_start = model_rs
8043            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8044            .unwrap();
8045        let batch_body = &model_rs[batch_start..];
8046        let batch_end = batch_body
8047            .find("\n    pub fn reset")
8048            .unwrap_or(batch_body.len());
8049        let batch_code = &batch_body[..batch_end];
8050
8051        assert!(
8052            batch_code.contains("dispatch_matmul_q8_batch"),
8053            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8054        );
8055    }
8056
8057    #[test]
8058    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8059        let config = minimal_q4_config();
8060        let model_rs = generate_model_rs(&config).unwrap();
8061
8062        let batch_start = model_rs
8063            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8064            .unwrap();
8065        let batch_body = &model_rs[batch_start..];
8066        let batch_end = batch_body
8067            .find("\n    pub fn reset")
8068            .unwrap_or(batch_body.len());
8069        let batch_code = &batch_body[..batch_end];
8070
8071        assert!(
8072            batch_code.contains("dispatch_matmul_q4_batch"),
8073            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8074        );
8075    }
8076
8077    #[test]
8078    fn generated_main_rs_uses_batched_prefill() {
8079        let config = minimal_config();
8080        let main_rs = generate_main_rs("test-model", &config).unwrap();
8081
8082        assert!(
8083            main_rs.contains("forward_prefill_batch"),
8084            "main.rs should use forward_prefill_batch for prompt tokens"
8085        );
8086    }
8087
8088    #[test]
8089    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8090        let config = minimal_config();
8091        let model_rs = generate_model_rs(&config).unwrap();
8092
8093        let batch_start = model_rs
8094            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8095            .unwrap();
8096        let batch_body = &model_rs[batch_start..];
8097        let batch_end = batch_body
8098            .find("\n    pub fn reset")
8099            .unwrap_or(batch_body.len());
8100        let batch_code = &batch_body[..batch_end];
8101
8102        assert!(
8103            batch_code.contains("dispatch_matmul_batch"),
8104            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8105        );
8106        // Should NOT use Q8 or Q4 batch dispatch
8107        assert!(
8108            !batch_code.contains("dispatch_matmul_q8_batch"),
8109            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8110        );
8111        assert!(
8112            !batch_code.contains("dispatch_matmul_q4_batch"),
8113            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8114        );
8115    }
8116
8117    #[test]
8118    fn forward_uses_cpu_embedding_lookup() {
8119        let config = minimal_config();
8120        let model_rs = generate_model_rs(&config).unwrap();
8121
8122        // Find just the forward() body (not forward_profile)
8123        let forward_start = model_rs
8124            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8125            .unwrap();
8126        let forward_body = &model_rs[forward_start..];
8127        let forward_end = forward_body
8128            .find("\n    pub fn forward_profile")
8129            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8130            .unwrap_or(forward_body.len());
8131        let forward_code = &forward_body[..forward_end];
8132
8133        // forward() should use CPU memcpy for embedding lookup (unified memory)
8134        assert!(
8135            forward_code.contains("embed_buf.contents()"),
8136            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8137        );
8138        assert!(
8139            forward_code.contains("copy_nonoverlapping"),
8140            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8141        );
8142        // forward() should NOT use GPU dispatch for embedding
8143        assert!(
8144            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8145            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8146        );
8147    }
8148
8149    #[test]
8150    fn forward_profile_method_exists() {
8151        let config = minimal_config();
8152        let model_rs = generate_model_rs(&config).unwrap();
8153
8154        assert!(
8155            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8156            "MetalModel should have forward_profile() method"
8157        );
8158        // Profile method should print timing information
8159        assert!(
8160            model_rs.contains("[profile]"),
8161            "forward_profile() should print timing with [profile] prefix"
8162        );
8163        assert!(
8164            model_rs.contains("d_embed"),
8165            "forward_profile() should measure embedding time"
8166        );
8167        assert!(
8168            model_rs.contains("d_layers"),
8169            "forward_profile() should measure layer time"
8170        );
8171        assert!(
8172            model_rs.contains("d_logits"),
8173            "forward_profile() should measure logits time"
8174        );
8175    }
8176
8177    #[test]
8178    fn generated_cli_has_profile_flag() {
8179        let config = minimal_config();
8180        let main_rs = generate_main_rs("test-model", &config).unwrap();
8181
8182        assert!(
8183            main_rs.contains("--profile"),
8184            "CLI should support --profile flag"
8185        );
8186        assert!(
8187            main_rs.contains("forward_profile"),
8188            "CLI should call forward_profile when --profile is set"
8189        );
8190    }
8191
8192    #[test]
8193    fn generated_cli_has_thermal_yield() {
8194        let config = minimal_config();
8195        let main_rs = generate_main_rs("test-model", &config).unwrap();
8196
8197        assert!(
8198            main_rs.contains("yield_now()"),
8199            "CLI generation loop should include thread::yield_now() for thermal management"
8200        );
8201    }
8202
8203    // ── Real-world validation tests ──────────────────────────────────────
8204
8205    #[test]
8206    fn generated_forward_handles_single_token_prompt() {
8207        // With a single token (the first prompt token), forward() should work
8208        // at pos=0 where seq_len=1. The attention kernel must handle the case
8209        // where there is only one KV entry (no prefill context).
8210        let config = minimal_config();
8211        let model_rs = generate_model_rs(&config).unwrap();
8212
8213        // The forward function should accept any u32 token_id (no minimum pos guard)
8214        let forward_start = model_rs
8215            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8216            .expect("forward() must exist");
8217        let forward_body = &model_rs[forward_start..forward_start + 400];
8218
8219        // Should NOT require pos > 0 or seq_len > 1
8220        assert!(
8221            !forward_body.contains("assert!(self.pos > 0"),
8222            "forward() must accept pos=0 (first token with no prefill)"
8223        );
8224
8225        // The attention kernel should handle seq_len=1 via the pos field
8226        assert!(
8227            model_rs.contains("self.pos"),
8228            "forward() should use self.pos to track sequence position"
8229        );
8230    }
8231
8232    #[test]
8233    fn generated_reset_clears_kv_cache_position() {
8234        // After reset(), the model should be in a clean state. The pos field
8235        // must be 0 so new generation starts from scratch.
8236        let config = minimal_config();
8237        let model_rs = generate_model_rs(&config).unwrap();
8238
8239        let reset_start = model_rs
8240            .find("pub fn reset(&mut self)")
8241            .expect("reset() must exist");
8242        let reset_body = &model_rs[reset_start..reset_start + 200];
8243
8244        // Reset must zero the position counter
8245        assert!(
8246            reset_body.contains("self.pos = 0"),
8247            "reset() must set self.pos = 0"
8248        );
8249
8250        // Verify reset clears prev_cmd (double-buffering state)
8251        assert!(
8252            reset_body.contains("self.prev_cmd = None"),
8253            "reset() should clear prev_cmd for clean command buffer state"
8254        );
8255    }
8256
8257    #[test]
8258    fn generated_serve_handles_empty_messages_gracefully() {
8259        // The serve endpoint should not crash when receiving an empty messages array.
8260        // The format_chat_messages function should handle this gracefully.
8261        let config = minimal_config();
8262        let main_rs = generate_main_rs("test-model", &config).unwrap();
8263
8264        // The format_chat_messages function should exist and handle empty input
8265        let format_fn_start = main_rs
8266            .find("fn format_chat_messages")
8267            .expect("format_chat_messages must exist");
8268        let format_fn_body =
8269            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8270
8271        // It should iterate over messages (an empty slice produces an empty loop)
8272        assert!(
8273            format_fn_body.contains("for msg in messages"),
8274            "format_chat_messages should iterate over the messages slice"
8275        );
8276        // It should always append the assistant prompt suffix
8277        assert!(
8278            format_fn_body.contains("<|im_start|>assistant"),
8279            "format_chat_messages should always append assistant prompt header"
8280        );
8281
8282        // The serve function should call model.reset() before each request
8283        let serve_fn_start = main_rs
8284            .find("fn serve(")
8285            .expect("serve function must exist");
8286        let serve_fn_body = &main_rs[serve_fn_start..];
8287        assert!(
8288            serve_fn_body.contains("model.reset()"),
8289            "serve function should reset model between requests"
8290        );
8291    }
8292
8293    #[test]
8294    fn generated_model_forward_increments_pos() {
8295        // Each forward() call must increment self.pos so the next token
8296        // uses the correct RoPE position and KV cache offset.
8297        let config = minimal_config();
8298        let model_rs = generate_model_rs(&config).unwrap();
8299
8300        let forward_start = model_rs
8301            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8302            .unwrap();
8303        let forward_body = &model_rs[forward_start..];
8304        let forward_end = forward_body
8305            .find("\n    pub fn forward_profile")
8306            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8307            .or_else(|| forward_body.find("\n    fn dispatch_"))
8308            .unwrap_or(forward_body.len());
8309        let forward_code = &forward_body[..forward_end];
8310
8311        assert!(
8312            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8313            "forward() must increment self.pos after processing a token"
8314        );
8315    }
8316}