Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a valid Cargo package / binary name.
///
/// Every character that is not an ASCII letter, ASCII digit or `-` becomes
/// `-`. Letters are lower-cased, and leading/trailing dashes are trimmed.
/// Cargo rejects empty package names and names starting with a digit, so an
/// empty result falls back to `"model"`. A leading digit gets a `model-`
/// prefix, e.g. `"135M"` becomes `"model-135m"`.
fn sanitize_name(name: &str) -> String {
    let dashed: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '-' {
                c.to_ascii_lowercase()
            } else {
                '-'
            }
        })
        .collect();
    let trimmed = dashed.trim_matches('-');
    match trimmed.chars().next() {
        None => String::from("model"),
        Some(first) if first.is_ascii_digit() => format!("model-{trimmed}"),
        Some(_) => trimmed.to_owned(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    // The attention-scores shared array must be at least effective_seq_len
122    // elements; anything smaller silently overflows for long prompts.  For
123    // small-context models (135M at 2K), a smaller array saves threadgroup
124    // memory and improves occupancy — so we size it precisely.
125    let attn_scores_size = config.max_seq_len.min(4096);
126    r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154    device const float* matrix [[buffer(0)]],
155    device const float* vector [[buffer(1)]],
156    device float* output       [[buffer(2)]],
157    constant uint& rows        [[buffer(3)]],
158    constant uint& cols        [[buffer(4)]],
159    uint tgid [[threadgroup_position_in_grid]],
160    uint tid [[thread_index_in_threadgroup]],
161    uint simd_lane [[thread_index_in_simdgroup]],
162    uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164    // Cooperatively load vector into threadgroup shared memory
165    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166    for (uint i = tid; i < cols; i += 256) {
167        vec_tile[i] = vector[i];
168    }
169    threadgroup_barrier(mem_flags::mem_threadgroup);
170
171    // Each simdgroup handles 8 consecutive rows
172    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173    if (row_base >= rows) return;
174
175    uint base0 = row_base * cols;
176    uint base1 = (row_base + 1) * cols;
177    uint base2 = (row_base + 2) * cols;
178    uint base3 = (row_base + 3) * cols;
179    uint base4 = (row_base + 4) * cols;
180    uint base5 = (row_base + 5) * cols;
181    uint base6 = (row_base + 6) * cols;
182    uint base7 = (row_base + 7) * cols;
183
184    // float4 vectorized accumulation across 8 rows
185    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
186    float4 sum4_0 = float4(0.0f);
187    float4 sum4_1 = float4(0.0f);
188    float4 sum4_2 = float4(0.0f);
189    float4 sum4_3 = float4(0.0f);
190    float4 sum4_4 = float4(0.0f);
191    float4 sum4_5 = float4(0.0f);
192    float4 sum4_6 = float4(0.0f);
193    float4 sum4_7 = float4(0.0f);
194
195    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205    }
206
207    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216    // Handle remaining elements (cols not divisible by 128)
217    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218        float vv = vec_tile[j];
219        sum0 += matrix[base0 + j] * vv;
220        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227    }
228
229    // Simdgroup hardware warp-level reduction
230    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235    // Only first lane writes the results
236    if (simd_lane == 0) {
237        if (row_base     < rows) output[row_base]     = sum0;
238        if (row_base + 1 < rows) output[row_base + 1] = sum1;
239        if (row_base + 2 < rows) output[row_base + 2] = sum2;
240        if (row_base + 3 < rows) output[row_base + 3] = sum3;
241        if (row_base + 4 < rows) output[row_base + 4] = sum4;
242        if (row_base + 5 < rows) output[row_base + 5] = sum5;
243        if (row_base + 6 < rows) output[row_base + 6] = sum6;
244        if (row_base + 7 < rows) output[row_base + 7] = sum7;
245    }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253    device const float* input   [[buffer(0)]],
254    device const float* weight  [[buffer(1)]],
255    device float* output        [[buffer(2)]],
256    constant uint& n            [[buffer(3)]],
257    constant float& eps         [[buffer(4)]],
258    uint tid [[thread_index_in_threadgroup]])
259{
260    // Each thread accumulates partial sum-of-squares
261    float sum_sq = 0.0f;
262    for (uint i = tid; i < n; i += 256) {
263        float v = input[i];
264        sum_sq += v * v;
265    }
266
267    // Simdgroup-level reduction (hardware warp sum)
268    sum_sq = simd_sum(sum_sq);
269
270    // Cross-simdgroup reduction via shared memory
271    threadgroup float shared[8];
272    uint simd_id = tid / 32;
273    uint simd_lane = tid % 32;
274    if (simd_lane == 0) {
275        shared[simd_id] = sum_sq;
276    }
277    threadgroup_barrier(mem_flags::mem_threadgroup);
278
279    // First thread computes final inverse RMS
280    if (tid == 0) {
281        float total = 0.0f;
282        for (uint i = 0; i < 8; i++) {
283            total += shared[i];
284        }
285        shared[0] = fast::rsqrt(total / float(n) + eps);
286    }
287    threadgroup_barrier(mem_flags::mem_threadgroup);
288
289    float inv_rms = shared[0];
290
291    // Normalize
292    for (uint i = tid; i < n; i += 256) {
293        output[i] = input[i] * inv_rms * weight[i];
294    }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301    device float* data        [[buffer(0)]],
302    constant uint& num_heads  [[buffer(1)]],
303    constant uint& head_dim   [[buffer(2)]],
304    constant uint& pos        [[buffer(3)]],
305    constant float& theta     [[buffer(4)]],
306    uint id [[thread_position_in_grid]])
307{
308    uint half_dim = head_dim / 2;
309    uint total_pairs = num_heads * half_dim;
310    if (id >= total_pairs) return;
311
312    uint h = id / half_dim;
313    uint i = id % half_dim;
314    uint off = h * head_dim;
315
316    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317    float angle = float(pos) * freq;
318    float c = cos(angle);
319    float s = sin(angle);
320
321    float x0 = data[off + 2 * i];
322    float x1 = data[off + 2 * i + 1];
323    data[off + 2 * i]     = x0 * c - x1 * s;
324    data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331    device float* data       [[buffer(0)]],
332    constant uint& n         [[buffer(1)]],
333    uint tid [[thread_index_in_threadgroup]],
334    uint tg_size [[threads_per_threadgroup]])
335{
336    threadgroup float shared_val[256];
337
338    // Pass 1: find max
339    float local_max = -INFINITY;
340    for (uint i = tid; i < n; i += tg_size) {
341        local_max = max(local_max, data[i]);
342    }
343    shared_val[tid] = local_max;
344    threadgroup_barrier(mem_flags::mem_threadgroup);
345
346    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347        if (tid < stride) {
348            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349        }
350        threadgroup_barrier(mem_flags::mem_threadgroup);
351    }
352    float max_val = shared_val[0];
353    threadgroup_barrier(mem_flags::mem_threadgroup);
354
355    // Pass 2: exp and sum
356    float local_sum = 0.0f;
357    for (uint i = tid; i < n; i += tg_size) {
358        float e = fast::exp(data[i] - max_val);
359        data[i] = e;
360        local_sum += e;
361    }
362    shared_val[tid] = local_sum;
363    threadgroup_barrier(mem_flags::mem_threadgroup);
364
365    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366        if (tid < stride) {
367            shared_val[tid] += shared_val[tid + stride];
368        }
369        threadgroup_barrier(mem_flags::mem_threadgroup);
370    }
371    float inv_sum = 1.0f / shared_val[0];
372    threadgroup_barrier(mem_flags::mem_threadgroup);
373
374    // Pass 3: normalize
375    for (uint i = tid; i < n; i += tg_size) {
376        data[i] *= inv_sum;
377    }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384    device const float* gate [[buffer(0)]],
385    device const float* up   [[buffer(1)]],
386    device float* output     [[buffer(2)]],
387    constant uint& n         [[buffer(3)]],
388    uint id [[thread_position_in_grid]])
389{
390    if (id >= n) return;
391    float g = gate[id];
392    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397//   gate = gate_up[0..n], up = gate_up[n..2*n]
398//   output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400    device const float* gate_up [[buffer(0)]],
401    device float* output        [[buffer(1)]],
402    constant uint& n            [[buffer(2)]],
403    uint id [[thread_position_in_grid]])
404{
405    if (id >= n) return;
406    float g = gate_up[id];
407    float u = gate_up[n + id];
408    output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414    device const float* a  [[buffer(0)]],
415    device const float* b  [[buffer(1)]],
416    device float* output   [[buffer(2)]],
417    constant uint& n       [[buffer(3)]],
418    uint id [[thread_position_in_grid]])
419{
420    if (id >= n) return;
421    output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428    device const float* src [[buffer(0)]],
429    device float* dst       [[buffer(1)]],
430    constant uint& count    [[buffer(2)]],
431    uint id [[thread_position_in_grid]])
432{
433    if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440    device const float* src     [[buffer(0)]],
441    device float* dst           [[buffer(1)]],
442    constant uint& src_offset   [[buffer(2)]],  // in floats
443    constant uint& count        [[buffer(3)]],
444    uint id [[thread_position_in_grid]])
445{
446    if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── add_inplace ─────────────────────────────────────────────────────────
450// In-place residual connection: a[i] += b[i]
451// Avoids a separate blit copy for residual add, reducing encoder overhead.
452kernel void add_inplace(
453    device float* a        [[buffer(0)]],
454    device const float* b  [[buffer(1)]],
455    constant uint& n       [[buffer(2)]],
456    uint id [[thread_position_in_grid]])
457{
458    if (id >= n) return;
459    a[id] += b[id];
460}
461
462// ── matmul_vec_q8 ─────────────────────────────────────────────────────
463// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
464// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
465// Operates directly on quantized weights to halve memory bandwidth vs f32,
466// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
467//
468// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
469// because int8->float conversion doubles register demand.  Fully unrolled
470// inner loop with float4 vector loads from shared memory eliminates loop
471// overhead and enables better instruction scheduling.
472// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
473constant constexpr uint Q8_ROWS_PER_SG = 4;
474constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
475
476// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
477// doubles ALU work, so the same register budget applies).
478constant constexpr uint Q4_ROWS_PER_SG = 4;
479constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
480
481kernel void matmul_vec_q8(
482    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
483    device const float* vector   [[buffer(1)]],  // f32 input
484    device float* output         [[buffer(2)]],
485    constant uint& rows          [[buffer(3)]],
486    constant uint& cols          [[buffer(4)]],  // number of elements per row
487    uint tgid [[threadgroup_position_in_grid]],
488    uint tid [[thread_index_in_threadgroup]],
489    uint simd_lane [[thread_index_in_simdgroup]],
490    uint simd_id [[simdgroup_index_in_threadgroup]])
491{
492    // Load vector into shared memory
493    threadgroup float vec_tile[VEC_TILE_SIZE];
494    for (uint i = tid; i < cols; i += 256) {
495        vec_tile[i] = vector[i];
496    }
497    threadgroup_barrier(mem_flags::mem_threadgroup);
498
499    // Each simdgroup handles 4 consecutive rows (lower register pressure)
500    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
501    if (row_base >= rows) return;
502
503    // Q8_0: each block is 34 bytes for 32 elements
504    uint blocks_per_row = cols / 32;
505    uint row_bytes = blocks_per_row * 34;
506
507    // Pointers to each row's Q8_0 data
508    device const uchar* r0 = matrix + row_base * row_bytes;
509    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
510    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
511    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
512
513    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
514
515    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
516        uint bb = blk * 34;
517        uint vb = blk * 32;
518
519        // Prefetch all 4 scales
520        float sc0 = float(*(device const half*)(r0 + bb));
521        float sc1 = float(*(device const half*)(r1 + bb));
522        float sc2 = float(*(device const half*)(r2 + bb));
523        float sc3 = float(*(device const half*)(r3 + bb));
524
525        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
526        // Q8_0 block layout where the int8 data starts at offset +2 from a
527        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
528        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
529        // reduction in memory transactions. Metal's char16/packed_char16 are
530        // reserved types and packed_*int4 require >=4-byte alignment which
531        // this layout does not provide, so packed_short4 is the widest valid
532        // vectorized load.
533        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
534        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
535        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
536        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
537
538        // Load all 8 float4 vector values for this 32-element block from shared memory
539        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
540        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
541        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
542        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
543        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
544        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
545        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
546        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
547
548        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
549        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
550        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
551            short4 _s = short4(SHORT4); \
552            char2 _a = as_type<char2>(_s.x); \
553            char2 _b = as_type<char2>(_s.y); \
554            char2 _c = as_type<char2>(_s.z); \
555            char2 _d = as_type<char2>(_s.w); \
556            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
557            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
558        }
559
560        float4 f0, f1;
561        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
562
563        // Row 0: 4 short4 loads cover 32 int8 weights
564        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
565        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
566        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
567        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
568
569        // Row 1
570        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
571        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
572        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
573        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
574
575        // Row 2
576        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
577        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
578        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
579        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
580
581        // Row 3
582        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
583        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
584        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
585        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
586
587        #undef Q8_UNPACK8
588
589        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
590    }
591
592    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
593    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
594
595    if (simd_lane == 0) {
596        if (row_base     < rows) output[row_base]     = sum0;
597        if (row_base + 1 < rows) output[row_base + 1] = sum1;
598        if (row_base + 2 < rows) output[row_base + 2] = sum2;
599        if (row_base + 3 < rows) output[row_base + 3] = sum3;
600    }
601}
602
603// ── matmul_vec_q4 ─────────────────────────────────────────────────────
604// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
605// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
606// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
607// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
608//
609// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
610// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
611kernel void matmul_vec_q4(
612    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
613    device const float* vector   [[buffer(1)]],  // f32 input
614    device float* output         [[buffer(2)]],
615    constant uint& rows          [[buffer(3)]],
616    constant uint& cols          [[buffer(4)]],  // number of elements per row
617    uint tgid [[threadgroup_position_in_grid]],
618    uint tid [[thread_index_in_threadgroup]],
619    uint simd_lane [[thread_index_in_simdgroup]],
620    uint simd_id [[simdgroup_index_in_threadgroup]])
621{
622    // Load vector into shared memory
623    threadgroup float vec_tile[VEC_TILE_SIZE];
624    for (uint i = tid; i < cols; i += 256) {
625        vec_tile[i] = vector[i];
626    }
627    threadgroup_barrier(mem_flags::mem_threadgroup);
628
629    // Each simdgroup handles 4 consecutive rows
630    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
631    if (row_base >= rows) return;
632
633    // Q4_0: each block is 18 bytes for 32 elements
634    uint blocks_per_row = cols / 32;
635    uint row_bytes = blocks_per_row * 18;
636
637    // Pointers to each row's Q4_0 data
638    device const uchar* r0 = matrix + row_base * row_bytes;
639    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
640    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
641    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
642
643    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
644
645    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
646        uint bb = blk * 18;
647        uint vb = blk * 32;
648
649        // Prefetch all 4 scales
650        float sc0 = float(*(device const half*)(r0 + bb));
651        float sc1 = float(*(device const half*)(r1 + bb));
652        float sc2 = float(*(device const half*)(r2 + bb));
653        float sc3 = float(*(device const half*)(r3 + bb));
654
655        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
656        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
657        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
658        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
659        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
660
661        // Load 8 float4 vector values for 32 elements from shared memory
662        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
663        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
664        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
665        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
666        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
667        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
668        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
669        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
670        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
671
672        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
673        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
674        float bd0=0, bd1=0, bd2=0, bd3=0;
675        uchar4 b;
676
677        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
678        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
679                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
680                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
681                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
682        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
683                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
684                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
685                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
686        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
687                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
688                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
689                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
690        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
691                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
692                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
693                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
694
695        // Row 1
696        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
697                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
698                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
699                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
700        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
701                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
702                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
703                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
704        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
705                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
706                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
707                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
708        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
709                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
710                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
711                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
712
713        // Row 2
714        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
715                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
716                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
717                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
718        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
719                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
720                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
721                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
722        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
723                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
724                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
725                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
726        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
727                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
728                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
729                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
730
731        // Row 3
732        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
733                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
734                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
735                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
736        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
737                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
738                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
739                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
740        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
741                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
742                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
743                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
744        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
745                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
746                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
747                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
748
749        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
750    }
751
752    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
753    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
754
755    if (simd_lane == 0) {
756        if (row_base     < rows) output[row_base]     = sum0;
757        if (row_base + 1 < rows) output[row_base + 1] = sum1;
758        if (row_base + 2 < rows) output[row_base + 2] = sum2;
759        if (row_base + 3 < rows) output[row_base + 3] = sum3;
760    }
761}
762
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then weighted sum of V.
// Each threadgroup handles one head. The strides below (s += 8 across
// simdgroups, += 256 across threads, 8-entry reduction scratch) assume a
// 256-thread threadgroup (8 simdgroups of 32 lanes).
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
//
// Query head h reads KV head h / (num_heads / num_kv_heads), so num_heads is
// expected to be a multiple of (and >=) num_kv_heads. Scores are staged in a
// threadgroup array of ATTN_SCORES_SIZE floats; only the first
// min(seq_len, ATTN_SCORES_SIZE) positions are attended to, so the host must
// keep seq_len within that bound for exact results.
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // The whole threadgroup shares one head, so this early exit is uniform
    // and cannot strand other threads at a barrier.
    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Never index past the threadgroup score buffer: an over-long seq_len
    // would otherwise write out of bounds of `scores` (undefined behavior
    // that can clobber the reduction scratch arrays declared below).
    uint n_ctx = min(seq_len, uint(ATTN_SCORES_SIZE));

    // 1/sqrt(head_dim) is loop-invariant; compute it once.
    float scale = fast::rsqrt(float(head_dim));

    // Step 1: Compute attention scores Q·K^T with simdgroup reduction.
    // One simdgroup per position; lanes stride over head_dim and simd_sum
    // combines the partial dot products.
    threadgroup float scores[ATTN_SCORES_SIZE];  // sized to match the host MAX_SEQ_LEN cap

    for (uint s = simd_id; s < n_ctx; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax over scores (cooperative)
    // Find max (subtracted before exp for numerical stability)
    float local_max = -INFINITY;
    for (uint s = tid; s < n_ctx; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exp and sum; each thread only touches the entries it exponentiates.
    float local_sum = 0.0;
    for (uint s = tid; s < n_ctx; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    for (uint s = tid; s < n_ctx; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V: output = scores · V
    // Each thread handles a range of head_dim dimensions.
    // Process 4 sequence positions at a time for better ILP and reduced
    // loop overhead (4 score reads, 4 V loads per iteration).
    uint seq_len4 = n_ctx & ~3u;  // largest multiple of 4 <= n_ctx
    uint v_stride = num_kv_heads * head_dim;
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * v_cache[s * v_stride + v_base]
                 + sc1 * v_cache[(s+1) * v_stride + v_base]
                 + sc2 * v_cache[(s+2) * v_stride + v_base]
                 + sc3 * v_cache[(s+3) * v_stride + v_base];
        }
        for (uint s = seq_len4; s < n_ctx; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
879
880// ── Batched prefill kernels ────────────────────────────────────────────
881// These kernels process M input vectors against the same weight matrix
882// in a single dispatch, converting mat-vec into mat-mat for better GPU
883// utilization during prompt prefill.
884
// ── rms_norm_batch ─────────────────────────────────────────────────────
// RMS-normalizes each row of a [M, n] batch and applies a per-feature gain:
//   out[t, i] = in[t, i] * rsqrt(mean_j(in[t, j]^2) + eps) * weight[i]
// One threadgroup per row (grid = M threadgroups). The 256-wide strides and
// the 8-slot reduction scratch assume 256 threads per threadgroup.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // Uniform per threadgroup, so no thread is left waiting at a barrier.
    if (tgid >= num_tokens) return;

    uint row_start = tgid * n;

    // Each thread accumulates squares over its strided slice of the row.
    float partial = 0.0f;
    for (uint col = tid; col < n; col += 256) {
        float x = input[row_start + col];
        partial += x * x;
    }

    // First reduce inside each 32-lane simdgroup...
    partial = simd_sum(partial);

    // ...then lane 0 of every simdgroup publishes its total for thread 0.
    threadgroup float scratch[8];
    uint lane = tid % 32;
    uint group = tid / 32;
    if (lane == 0) {
        scratch[group] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float sq_total = 0.0f;
        for (uint g = 0; g < 8; ++g) {
            sq_total += scratch[g];
        }
        // Slot 0 is reused to broadcast the inverse RMS.
        scratch[0] = fast::rsqrt(sq_total / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float inv_rms = scratch[0];
    for (uint col = tid; col < n; col += 256) {
        output[row_start + col] = input[row_start + col] * inv_rms * weight[col];
    }
}
934
// ── rope_batch ─────────────────────────────────────────────────────────
// Rotary Position Embedding applied in place to a [M, num_heads * head_dim]
// batch, where token t is rotated by angle positions[t] * freq(pair).
// Adjacent elements (2*pair, 2*pair + 1) of each head form one rotated pair;
// one thread per (token, head, pair), flattened into a 1-D grid.
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    // Split the flat index into (token, head, pair).
    const uint token = id / pairs_per_token;
    const uint within_token = id % pairs_per_token;
    const uint head = within_token / pairs_per_head;
    const uint pair = within_token % pairs_per_head;

    const uint elem = token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    // freq = theta^(-2*pair/head_dim)
    float freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    float angle = float(positions[token]) * freq;
    float cos_a = cos(angle);
    float sin_a = sin(angle);

    float even = data[elem];
    float odd = data[elem + 1];
    data[elem] = even * cos_a - odd * sin_a;
    data[elem + 1] = even * sin_a + odd * cos_a;
}
969
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SwiGLU-style activation over a batch. Each token row of gate_up holds
// the gate half followed by the up half ([M, 2*n]); the result row is
//   output[t, i] = silu(gate[t, i]) * up[t, i],  silu(g) = g / (1 + e^-g).
// One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint row = id / n;
    const uint col = id % n;
    const uint gate_idx = row * 2 * n + col;

    float g = gate_up[gate_idx];
    float u = gate_up[gate_idx + n];  // matching element of the up half
    output[row * n + col] = (g / (1.0f + fast::exp(-g))) * u;
}
989
// ── add_inplace_batch ──────────────────────────────────────────────────
// Element-wise residual add over a flattened [M * n] batch: a[i] = a[i] + b[i].
// One thread per element; threads beyond `total` do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1001
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gathers M rows of the [vocab_size, dim] embedding table into a contiguous
// [M, dim] buffer: output row t is embedding row tokens[t].
// One thread per output element. Token IDs are not range-checked here.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    const uint slot = id / dim;     // position within the batch
    const uint feature = id % dim;  // column within the embedding row
    const uint vocab_row = tokens[slot];
    output[id] = embed[vocab_row * dim + feature];
}
1019
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair: it stages the token's
// input vector in threadgroup memory (cols must not exceed VEC_TILE_SIZE),
// then each simdgroup accumulates 8 consecutive output rows starting at
// row_base (the unrolled body assumes ROWS_PER_SIMDGROUP == 8 — confirm).
//
// Vector loads: the float4 path dereferences `device const float4*`, which
// needs 16-byte alignment. Row r starts at float offset r * cols, so that is
// only guaranteed when cols is a multiple of 4; for other widths the float4
// path is skipped and the scalar loop covers every column.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Flat grid: tgid = token * row_tgs + row-group index.
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory (256-thread stride).
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so simdgroups past the last row may exit early.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // float4 main loop: 32 lanes x 4 columns = 128 columns per step, so the
    // vector path covers the largest multiple of 128 columns — but only when
    // every row start is float4-aligned (cols % 4 == 0).
    uint cols_vec4 = (cols % 4u == 0u) ? (cols & ~127u) : 0u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    // Row 0 is always valid (row_base < rows); the others are guarded so the
    // last row group never reads past the end of `matrix`.
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail (or the whole row when the float4 path was skipped).
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1121
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups, one per (token, row group).
//
// Q8_0 layout as read below: each row is cols/32 blocks of 34 bytes — a
// half-precision scale followed by 32 signed 8-bit weights. Each simdgroup
// produces 4 consecutive output rows starting at row_base (presumably
// Q8_ROWS_PER_SG == 4 — confirm), with lanes striding over blocks.
// When rows is not a multiple of 4, the trailing rows of the last group are
// redirected to the final valid row so no lane reads past the end of
// `matrix`; those duplicate sums are never stored.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Flat grid: tgid = token * row_tgs + row-group index.
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory (cols must not
    // exceed VEC_TILE_SIZE).
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so simdgroups past the last row may exit early.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // row_base < rows here, so rows >= 1 and last_row is valid. Rows beyond
    // the matrix alias the last row instead of reading out of bounds; their
    // results are dropped by the guarded stores below.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // first input element covered by this block

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Reinterpret 4 shorts as 8 signed bytes (low byte first) and widen
        // them to two float4s of consecutive weights.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply the per-block scale once to the integer-weight dot product.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1237
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers Q8_ROWS_PER_TG rows and TOKENS_PER_TG_Q8
// consecutive tokens, so the Q8_0 weight blocks are fetched once from device
// memory and reused for every token in the tile (1/TOKENS_PER_TG_Q8 the
// weight bandwidth of the per-token dispatch).
//
// Grid: (ceil(rows/Q8_ROWS_PER_TG), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
// with simd_sum.  Token vectors are read directly from device memory inside
// the block loop (not cached in shared memory) so intermediate_size up to
// 8192 fits without spilling threadgroup memory.
//
// TOKENS_PER_TG_Q8 must stay 4: the kernel body keeps exactly four
// hand-written per-token accumulator sets (s0x..s3x) and token sections.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1251
1252kernel void matmul_q8_gemm_batch(
1253    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1254    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1255    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1256    constant uint& num_tokens    [[buffer(3)]],  // M
1257    constant uint& rows          [[buffer(4)]],
1258    constant uint& cols          [[buffer(5)]],
1259    uint2 tgid [[threadgroup_position_in_grid]],
1260    uint tid [[thread_index_in_threadgroup]],
1261    uint simd_lane [[thread_index_in_simdgroup]],
1262    uint simd_id [[simdgroup_index_in_threadgroup]])
1263{
1264    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1265    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1266    if (row_base >= rows || tok_base >= num_tokens) return;
1267
1268    // How many tokens in this tile are valid?
1269    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1270
1271    uint blocks_per_row = cols / 32;
1272    uint row_bytes = blocks_per_row * 34;
1273
1274    device const uchar* r0 = matrix + row_base * row_bytes;
1275    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1276    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1277    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1278
1279    // Accumulators: 4 tokens × 4 rows per simdgroup.
1280    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1281    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1282    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1283    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1284
1285    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1286        uint bb = blk * 34;
1287        uint vb = blk * 32;
1288
1289        // ── Load weight data ONCE per block (reused across all tokens) ──
1290        float sc0 = float(*(device const half*)(r0 + bb));
1291        float sc1 = float(*(device const half*)(r1 + bb));
1292        float sc2 = float(*(device const half*)(r2 + bb));
1293        float sc3 = float(*(device const half*)(r3 + bb));
1294
1295        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1296        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1297        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1298        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1299
1300        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1301            short4 _s = short4(SHORT4); \
1302            char2 _a = as_type<char2>(_s.x); \
1303            char2 _b = as_type<char2>(_s.y); \
1304            char2 _c = as_type<char2>(_s.z); \
1305            char2 _d = as_type<char2>(_s.w); \
1306            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1307            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1308        }
1309
1310        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1311        // registers for the duration of the block and are dotted against
1312        // every token's vector tile.
1313        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1314        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1315        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1316        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1317
1318        Q8_UNPACK8(d0[0], w0_0, w0_1);
1319        Q8_UNPACK8(d0[1], w0_2, w0_3);
1320        Q8_UNPACK8(d0[2], w0_4, w0_5);
1321        Q8_UNPACK8(d0[3], w0_6, w0_7);
1322
1323        Q8_UNPACK8(d1[0], w1_0, w1_1);
1324        Q8_UNPACK8(d1[1], w1_2, w1_3);
1325        Q8_UNPACK8(d1[2], w1_4, w1_5);
1326        Q8_UNPACK8(d1[3], w1_6, w1_7);
1327
1328        Q8_UNPACK8(d2[0], w2_0, w2_1);
1329        Q8_UNPACK8(d2[1], w2_2, w2_3);
1330        Q8_UNPACK8(d2[2], w2_4, w2_5);
1331        Q8_UNPACK8(d2[3], w2_6, w2_7);
1332
1333        Q8_UNPACK8(d3[0], w3_0, w3_1);
1334        Q8_UNPACK8(d3[1], w3_2, w3_3);
1335        Q8_UNPACK8(d3[2], w3_4, w3_5);
1336        Q8_UNPACK8(d3[3], w3_6, w3_7);
1337
1338        #undef Q8_UNPACK8
1339
1340        // ── For each token, read vector and accumulate against shared weights ──
1341        // Token 0 (always valid: tok_count >= 1).
1342        {
1343            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1344            float4 v0 = *(device const float4*)(a0);
1345            float4 v1 = *(device const float4*)(a0 + 4);
1346            float4 v2 = *(device const float4*)(a0 + 8);
1347            float4 v3 = *(device const float4*)(a0 + 12);
1348            float4 v4 = *(device const float4*)(a0 + 16);
1349            float4 v5 = *(device const float4*)(a0 + 20);
1350            float4 v6 = *(device const float4*)(a0 + 24);
1351            float4 v7 = *(device const float4*)(a0 + 28);
1352            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1353                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1354            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1355                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1356            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1357                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1358            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1359                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1360            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1361        }
1362        // Token 1
1363        if (tok_count > 1) {
1364            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1365            float4 v0 = *(device const float4*)(a1);
1366            float4 v1 = *(device const float4*)(a1 + 4);
1367            float4 v2 = *(device const float4*)(a1 + 8);
1368            float4 v3 = *(device const float4*)(a1 + 12);
1369            float4 v4 = *(device const float4*)(a1 + 16);
1370            float4 v5 = *(device const float4*)(a1 + 20);
1371            float4 v6 = *(device const float4*)(a1 + 24);
1372            float4 v7 = *(device const float4*)(a1 + 28);
1373            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1374                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1375            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1376                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1377            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1378                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1379            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1380                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1381            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1382        }
1383        // Token 2
1384        if (tok_count > 2) {
1385            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1386            float4 v0 = *(device const float4*)(a2);
1387            float4 v1 = *(device const float4*)(a2 + 4);
1388            float4 v2 = *(device const float4*)(a2 + 8);
1389            float4 v3 = *(device const float4*)(a2 + 12);
1390            float4 v4 = *(device const float4*)(a2 + 16);
1391            float4 v5 = *(device const float4*)(a2 + 20);
1392            float4 v6 = *(device const float4*)(a2 + 24);
1393            float4 v7 = *(device const float4*)(a2 + 28);
1394            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1395                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1396            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1397                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1398            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1399                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1400            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1401                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1402            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1403        }
1404        // Token 3
1405        if (tok_count > 3) {
1406            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1407            float4 v0 = *(device const float4*)(a3);
1408            float4 v1 = *(device const float4*)(a3 + 4);
1409            float4 v2 = *(device const float4*)(a3 + 8);
1410            float4 v3 = *(device const float4*)(a3 + 12);
1411            float4 v4 = *(device const float4*)(a3 + 16);
1412            float4 v5 = *(device const float4*)(a3 + 20);
1413            float4 v6 = *(device const float4*)(a3 + 24);
1414            float4 v7 = *(device const float4*)(a3 + 28);
1415            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1416                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1417            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1418                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1419            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1420                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1421            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1422                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1423            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1424        }
1425    }
1426
1427    // simdgroup reduction
1428    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1429    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1430    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1431    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1432
1433    if (simd_lane == 0) {
1434        device float* o0 = outputs + (tok_base + 0) * rows;
1435        if (row_base     < rows) o0[row_base]     = s00;
1436        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1437        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1438        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1439
1440        if (tok_count > 1) {
1441            device float* o1 = outputs + (tok_base + 1) * rows;
1442            if (row_base     < rows) o1[row_base]     = s10;
1443            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1444            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1445            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1446        }
1447        if (tok_count > 2) {
1448            device float* o2 = outputs + (tok_base + 2) * rows;
1449            if (row_base     < rows) o2[row_base]     = s20;
1450            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1451            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1452            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1453        }
1454        if (tok_count > 3) {
1455            device float* o3 = outputs + (tok_base + 3) * rows;
1456            if (row_base     < rows) o3[row_base]     = s30;
1457            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1458            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1459            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1460        }
1461    }
1462}
1463
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups (128 threads) per TG, each computing a single 8×8 output
// sub-tile via one simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes
// are cooperatively dequantized into threadgroup memory once per block and
// reused by all simdgroups in the tile.
//
// Threadgroup memory is exactly w_tile + t_tile (4 KB).  The partial-tile
// store path stages its results in t_tile once the K loop has finished,
// instead of reserving a separate scratch array.
//
// Assumptions (verified in the dispatch helper, falls back otherwise):
//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
//   * num_tokens may be any value; a partial token tile at the end of the
//     token axis is handled via a staged copy.
//   * the cooperative tile loads assume exactly 128 threads per threadgroup
//     (512 tile elements, 4 per thread).
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and no
    // thread is left waiting at the barriers below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared tiles (16*32 = 512 floats = 2 KB each, 4 KB total).  t_tile is
    // also reused as the partial-store staging buffer after the K loop.
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8;  // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // Q8_0 block: f16 scale + 32 int8

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Tokens past num_tokens are zero-filled so the MMA stays well-defined.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B,
                w_tile + sg_row_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Keep the next iteration from overwriting tiles still being read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage the 8×8 accumulator in threadgroup memory
        // and copy only the valid token rows.  The K loop ended with a
        // threadgroup barrier, so t_tile is no longer read; each simdgroup
        // reuses its own disjoint 64-float slice of it (4 × 64 <= 512).
        threadgroup float* scratch = t_tile + simd_id * 64;
        simdgroup_store(C, scratch, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        // Element e of the staged tile is (token e / 8, row e % 8); spread the
        // valid * 8 copies over all 32 lanes instead of serializing on lane 0.
        for (uint e = lane; e < valid * 8; e += 32) {
            outputs[(out_tok + e / 8) * rows + out_row + (e % 8)] = scratch[e];
        }
    }
}
1592
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
//
//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//     output sub-tiles (tok, row):
//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration — better FLOP/load ratio than the 16×16 single-accumulator
// kernel — and halves the number of threadgroups vs the 16×16 tile.
//
// Threadgroup memory is exactly w_tile + t_tile (8 KB).  The partial-tile
// store path stages its results in t_tile after the K loop instead of
// reserving a separate 4 KB scratch array.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
//   * 256 threads per threadgroup (1024 tile elements, 4 per thread)
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    // t_tile doubles as the partial-store staging buffer after the K loop.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // Q8_0 block: f16 scale + 32 int8

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load (tokens past num_tokens are zero-filled).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration from overwriting tiles still being read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage both accumulators in threadgroup memory.
        // The K loop ended with a threadgroup barrier, so t_tile is no longer
        // read; each simdgroup reuses its own 128-float slice of it
        // (8 simdgroups × 2 accs × 64 floats = 1024 = MMA32_TOK_TILE * 32).
        threadgroup float* scratch_a = t_tile + simd_id * 128;
        threadgroup float* scratch_b = scratch_a + 64;
        simdgroup_store(C_a, scratch_a, 8);
        simdgroup_store(C_b, scratch_b, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        // Element e of each staged tile is (token e / 8, row e % 8); spread
        // the copy over all 32 lanes instead of serializing on lane 0.
        for (uint e = lane; e < valid * 8; e += 32) {
            device float* dst = outputs + (out_tok + e / 8) * rows;
            dst[out_row_a + (e % 8)] = scratch_a[e];
            dst[out_row_b + (e % 8)] = scratch_b[e];
        }
    }
}
1733
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory.  This halves the tile footprint (4 KB total vs 8 KB for mma32),
// which can raise concurrent-threadgroup occupancy per GPU core on Apple
// Silicon.  The partial-token store path reuses that same 4 KB as its float
// staging buffer, so the kernel's total threadgroup allocation stays 4 KB.
//
// Precision:
//   * Each Q8_0 weight is int8 × f16 block scale, rounded to f16.  This fits
//     as long as |scale| * 127 <= 65504.
//   * Token activations are narrowed f32 → f16 before the MMA.  They lose
//     precision beyond ~3 significant digits and overflow above 65504.
//   * The accumulators stay `float`.  That limits error growth across the
//     K reduction but does not undo the input narrowing.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
// accumulators each.  Primary win vs mma32 is occupancy at moderate
// prefill lengths where the GPU is wave-starved.
//
// Assumes 256 threads per threadgroup and the same cols % 32 / rows % 32
// alignment as mma32.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // One 4 KB threadgroup allocation, carved into two 32×32 half tiles
    // (w_tile then t_tile, 2 KB each).  Declared as float so the storage is
    // float-aligned: after the K loop it is reused as the float staging
    // buffer for the partial-token store path.
    threadgroup float tg_mem[(MMA32_ROW_TILE + MMA32_TOK_TILE) * 32 / 2];
    threadgroup half* w_tile = (threadgroup half*)tg_mem;
    threadgroup half* t_tile = w_tile + MMA32_ROW_TILE * 32;

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // Q8_0 block: f16 scale + 32 int8

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16 (4 halves per thread).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing; zero-fill past num_tokens).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // half × half products accumulated into float accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration (or the staging reuse below) from
        // overwriting tiles still being read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage both accumulators in this simdgroup's
        // 128-float slice of tg_mem (8 × 128 = 1024 floats = all of tg_mem).
        // The K loop ended with a threadgroup barrier, so the half tiles that
        // share this storage are no longer read.
        threadgroup float* scratch_a = tg_mem + simd_id * 128;
        threadgroup float* scratch_b = scratch_a + 64;
        simdgroup_store(C_a, scratch_a, 8);
        simdgroup_store(C_b, scratch_b, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        // Element e of each staged tile is (token e / 8, row e % 8); spread
        // the copy over all 32 lanes instead of serializing on lane 0.
        for (uint e = lane; e < valid * 8; e += 32) {
            device float* dst = outputs + (out_tok + e / 8) * rows;
            dst[out_row_a + (e % 8)] = scratch_a[e];
            dst[out_row_b + (e % 8)] = scratch_b[e];
        }
    }
}
1862
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's.  That's double the FLOP-per-simdgroup-load compared to the
// 2-accumulator kernel and halves the thread count per threadgroup (128
// threads), which often improves occupancy on Apple GPUs where the
// concurrent-thread budget is the tighter limit than shared-memory size.
//
// Threadgroup memory is a single 4 KB allocation: two 32×32 half tiles
// during the K loop, reused afterwards as the float staging buffer for the
// partial-token store path.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles carved from one float-aligned 4 KB allocation
    // (w_tile then t_tile, 2 KB each).  The float view is reused as the
    // partial-store staging buffer after the K loop.
    threadgroup float tg_mem[(MMA32_ROW_TILE + MMA32_TOK_TILE) * 32 / 2];
    threadgroup half* w_tile = (threadgroup half*)tg_mem;
    threadgroup half* t_tile = w_tile + MMA32_ROW_TILE * 32;

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // Q8_0 block: f16 scale + 32 int8

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16; zero-fill past num_tokens).
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Keep the next iteration (or the staging reuse below) from
        // overwriting tiles still being read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  Full-tile fast path assumes full 16×16 valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else if (out_tok_top < num_tokens) {
        // Partial-token fallback: stage the four accumulators in this
        // simdgroup's 256-float slice of tg_mem (4 × 256 = 1024 floats = all
        // of tg_mem).  The K loop ended with a threadgroup barrier, so the
        // half tiles that share this storage are no longer read.  Quadrants
        // lying entirely past num_tokens skip this branch.
        threadgroup float* scratch = tg_mem + simd_id * 256;
        simdgroup_store(C_00, scratch +   0, 8);
        simdgroup_store(C_01, scratch +  64, 8);
        simdgroup_store(C_10, scratch + 128, 8);
        simdgroup_store(C_11, scratch + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        // Element e of each staged 8×8 tile is (token e / 8, row e % 8);
        // 64 elements per tile spread over all 32 lanes.
        for (uint e = lane; e < 64; e += 32) {
            uint m = e / 8;
            uint j = e % 8;
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst = outputs + t_top * rows;
                dst[out_row_lo + j] = scratch[  0 + e];
                dst[out_row_hi + j] = scratch[ 64 + e];
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst = outputs + t_bot * rows;
                dst[out_row_lo + j] = scratch[128 + e];
                dst[out_row_hi + j] = scratch[192 + e];
            }
        }
    }
}
2017
2018// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
2019// All-half MMA variant of matmul_q8_mma32_h4.
2020//
2021// Both the input matrices and the accumulators are simdgroup_matrix<half>.
2022// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
2023// rate (dual-issue FMA), so if Q8_0 precision holds through half
2024// accumulation this kernel can double the effective FLOP throughput on
2025// matmul-bound prefill.
2026//
2027// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
2028// activations at each layer are bounded (post-RMSNorm ≈ O(1)).  Summing
2029// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
2030// precision on extreme values, but the inputs are already quantized so the
2031// per-product error floor is higher than the half-precision rounding error.
2032// We verify correctness on 135M / 1B / 3B before enabling.
2033kernel void matmul_q8_mma32_hh4(
2034    device const uchar* matrix   [[buffer(0)]],
2035    device const float* inputs   [[buffer(1)]],
2036    device float* outputs        [[buffer(2)]],
2037    constant uint& num_tokens    [[buffer(3)]],
2038    constant uint& rows          [[buffer(4)]],
2039    constant uint& cols          [[buffer(5)]],
2040    uint2 tgid [[threadgroup_position_in_grid]],
2041    uint tid [[thread_index_in_threadgroup]],
2042    uint simd_id [[simdgroup_index_in_threadgroup]])
2043{
2044    uint row_base = tgid.x * MMA32_ROW_TILE;
2045    uint tok_base = tgid.y * MMA32_TOK_TILE;
2046    if (row_base >= rows || tok_base >= num_tokens) return;
2047
2048    threadgroup half w_tile[MMA32_ROW_TILE * 32];
2049    threadgroup half t_tile[MMA32_TOK_TILE * 32];
2050
2051    uint sg_tok_q = simd_id / 2;
2052    uint sg_row_q = simd_id % 2;
2053    uint sg_tok_base = sg_tok_q * 16;
2054    uint sg_row_base = sg_row_q * 16;
2055
2056    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
2057    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
2058    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
2059    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));
2060
2061    uint blocks_per_row = cols / 32;
2062    uint row_bytes = blocks_per_row * 34;
2063
2064    for (uint blk = 0; blk < blocks_per_row; blk++) {
2065        {
2066            uint base = tid * 8;
2067            for (uint ii = 0; ii < 8; ii++) {
2068                uint idx = base + ii;
2069                uint r = idx / 32;
2070                uint k = idx % 32;
2071                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
2072                float sc = float(*(device const half*)rp);
2073                int ival = int(*(device const int8_t*)(rp + 2 + k));
2074                w_tile[r * 32 + k] = half(float(ival) * sc);
2075            }
2076        }
2077        {
2078            uint base = tid * 8;
2079            for (uint ii = 0; ii < 8; ii++) {
2080                uint idx = base + ii;
2081                uint m = idx / 32;
2082                uint k = idx % 32;
2083                uint tok = tok_base + m;
2084                t_tile[m * 32 + k] = (tok < num_tokens)
2085                    ? half(inputs[tok * cols + blk * 32 + k])
2086                    : half(0);
2087            }
2088        }
2089        threadgroup_barrier(mem_flags::mem_threadgroup);
2090
2091        for (uint k_sub = 0; k_sub < 4; k_sub++) {
2092            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
2093            simdgroup_load(A_top,
2094                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
2095                32, ulong2(0, 0), false);
2096            simdgroup_load(A_bot,
2097                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
2098                32, ulong2(0, 0), false);
2099            simdgroup_load(B_lo,
2100                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
2101                32, ulong2(0, 0), true);
2102            simdgroup_load(B_hi,
2103                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
2104                32, ulong2(0, 0), true);
2105            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
2106            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
2107            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
2108            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
2109        }
2110
2111        threadgroup_barrier(mem_flags::mem_threadgroup);
2112    }
2113
2114    // Store half accumulators via scratch (must widen to f32 for device output).
2115    uint out_tok_top = tok_base + sg_tok_base + 0;
2116    uint out_tok_bot = tok_base + sg_tok_base + 8;
2117    uint out_row_lo  = row_base + sg_row_base + 0;
2118    uint out_row_hi  = row_base + sg_row_base + 8;
2119
2120    threadgroup half scratch[4 * 4 * 64];
2121    uint sg_base = simd_id * 256;
2122    simdgroup_store(C_00, scratch + sg_base +   0, 8);
2123    simdgroup_store(C_01, scratch + sg_base +  64, 8);
2124    simdgroup_store(C_10, scratch + sg_base + 128, 8);
2125    simdgroup_store(C_11, scratch + sg_base + 192, 8);
2126    simdgroup_barrier(mem_flags::mem_threadgroup);
2127    uint lane = tid % 32;
2128    if (lane == 0) {
2129        for (uint m = 0; m < 8; m++) {
2130            uint t_top = out_tok_top + m;
2131            if (t_top < num_tokens) {
2132                device float* dst0 = outputs + t_top * rows + out_row_lo;
2133                device float* dst1 = outputs + t_top * rows + out_row_hi;
2134                threadgroup const half* src0 = scratch + sg_base +   0 + m * 8;
2135                threadgroup const half* src1 = scratch + sg_base +  64 + m * 8;
2136                for (uint j = 0; j < 8; j++) {
2137                    dst0[j] = float(src0[j]);
2138                    dst1[j] = float(src1[j]);
2139                }
2140            }
2141            uint t_bot = out_tok_bot + m;
2142            if (t_bot < num_tokens) {
2143                device float* dst2 = outputs + t_bot * rows + out_row_lo;
2144                device float* dst3 = outputs + t_bot * rows + out_row_hi;
2145                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
2146                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
2147                for (uint j = 0; j < 8; j++) {
2148                    dst2[j] = float(src2[j]);
2149                    dst3[j] = float(src3[j]);
2150                }
2151            }
2152        }
2153    }
2154}
2155
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast of a per-row bias across every token's output row:
//     out[token, i] += bias[i]    for token in 0..num_tokens, i in 0..rows
// Used for Qwen2 QKV bias after the fused qkv matmul.
// Launch one thread per element of out; surplus threads do nothing.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows]
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // id is the flat element index; its column within the row selects the bias.
    if (id < num_tokens * rows) {
        out[id] += bias[id % rows];
    }
}
2172
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
//     outputs[token, row] = sum_k W[row, k] * inputs[token, k]
//
// Q4_0 rows are stored as 18-byte blocks of 32 weights: an f16 scale followed
// by 16 bytes whose low nibbles hold weights 0..15 and high nibbles weights
// 16..31, each offset by -8.  Each threadgroup stages its token's input row in
// threadgroup memory; each simdgroup then computes four consecutive output
// rows, its lanes striding over blocks and reducing with simd_sum.
//
// Assumptions: cols is a multiple of 32 and fits in vec_tile; the staging
// loop strides by 256, which presumes 256-thread threadgroups, and the four
// unrolled row streams presume Q4_ROWS_PER_SG == 4 — NOTE(review): confirm
// against the host dispatch and the constant definitions.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so returning before the barrier is safe.
    if (token >= num_tokens) return;

    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    // No barriers follow, so simdgroups past the last row may leave early.
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // When rows is not a multiple of 4, the last simdgroup's trailing rows do
    // not exist.  Clamp them to the last real row so every load below stays
    // inside the matrix buffer; their sums are discarded by the guarded
    // stores at the end.  (rows >= 1 here because row_base < rows.)
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;
        uint vb = blk * 32;

        // Per-block f16 scales.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // 16 packed nibble bytes per block, read 4 bytes at a time.
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 cover inputs 0..15 (low nibbles), v4..v7 inputs 16..31 (high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Reduce each row's per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2273
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter one of K or V from the batched QKV activations into a KV cache.
// Source rows are src_stride floats apart and the K/V slice begins
// src_offset floats into each row; the cache is written densely, kv_dim
// floats per position, starting at position base_pos.
// One thread per copied float (M * kv_dim threads); extras do nothing.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer
    device float* dst        [[buffer(1)]],  // KV cache
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    uint count = M * kv_dim;
    if (id < count) {
        uint token = id / kv_dim;
        uint col   = id - token * kv_dim;   // position within the KV vector
        uint from  = token * src_stride + src_offset + col;
        uint to    = (base_pos + token) * kv_dim + col;
        dst[to] = src[from];
    }
}
2296
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// The full row of seq_len scores is materialized in a threadgroup array whose
// length is substituted by the Rust generator; nothing here checks seq_len
// against that length, so longer sequences would write out of bounds
// (attention_flash_batch has no such limit).
//
// Assumptions visible in the code: 256 threads per threadgroup (loops stride
// by 256) made of 8 simdgroups (Step 1 strides by 8 and the reductions read
// 8 partials), and num_heads a multiple of num_kv_heads (grouped-query
// mapping below).
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // Uniform per threadgroup, so returning before the barriers is safe.
    if (token_idx >= M) return;

    // Grouped-query attention: consecutive Q heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup buffer holding one score per attended position.  Its length
    // is filled in by the Rust generator (an earlier note said it matches the
    // effective max_seq_len, 4096 for supported models — NOTE(review): confirm
    // the host never dispatches this kernel with a longer seq_len).
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction
    // Each simdgroup takes every 8th position; its 32 lanes split head_dim
    // and simd_sum combines the partial dot products.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a: max over all scores (per-simdgroup max, then thread 0 combines the 8).
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate in place (shifted by the max for stability) and sum.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Broadcast the reciprocal so every thread multiplies instead of divides.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // 2c: normalize the scores into probabilities.
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // Each active thread owns 4 consecutive output dims and walks all seq_len
    // positions with 16-byte loads.  Only head_dim/4 of the 256 threads do
    // work here (16 for head_dim=64); the gain is vectorized memory access,
    // not more active threads.  The float4 pointer casts need 16-byte
    // alignment, which holds when head_dim is a multiple of 4.
    // NOTE(review): if head_dim % 4 != 0 the scalar tail below runs, but the
    // float4 addresses above would then be misaligned for kv_head > 0.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        uint seq_len4 = seq_len & ~3u;
        // Unrolled by 4 positions; the remainder loop handles seq_len % 4.
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2422
// ── attention_flash_batch ─────────────────────────────────────────────
// Streaming attention with online softmax.  Same grid as attention_batch
// (M × num_heads threadgroups, one per (token, head) pair) but the scores
// matrix is never materialized.  K/V positions are processed in a tile of
// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
// the standard flash-attention recurrence:
//
//     m_new   = max(m_old, tile_max)
//     alpha   = exp(m_old - m_new)
//     l_new   = alpha * l_old + sum(exp(S - m_new))
//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
//     O_final = O / l
//
// Unlike attention_batch, whose threadgroup scores array has a fixed length
// chosen at generation time and must hold all seq_len scores, this kernel has
// no sequence-length cap: threadgroup memory use is O(head_dim + FLASH_K_TILE)
// regardless of seq_len.
//
// Assumptions: head_dim ≤ FLASH_MAX_HEAD_DIM (Llama/Qwen/Mistral/Phi-3 all
// satisfy this).  It is not checked here; a larger head_dim would index past
// q_sh/o_sh.  The loops stride by 256 threads and the reductions combine 8
// simdgroup partials, so 256-thread threadgroups are presumed
// — NOTE(review): confirm against the host dispatch.
constant constexpr uint FLASH_K_TILE = 32;
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;

kernel void attention_flash_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],
    constant uint& base_pos          [[buffer(5)]],
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // Uniform per threadgroup, so returning before the barriers is safe.
    if (token_idx >= M) return;

    // Grouped-query attention: consecutive Q heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: attend to [0, base_pos + token_idx]
    uint q_off = token_idx * q_stride + head * head_dim;
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup state:
    //   q_sh:      Q vector for this (token, head), loaded once
    //   o_sh:      running (unnormalized) output vector, updated each K tile
    //   scores_sh: scores for the current K tile only
    //   stats:     [running max, running sum]  (see flash-attention recurrence)
    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float scores_sh[FLASH_K_TILE];
    threadgroup float stats[2];
    threadgroup float sg_scratch[8];  // simdgroup-level reduction buffer, also used to broadcast alpha/m_new

    // --- Load Q (one row) and zero the running O ---
    for (uint d = tid; d < head_dim; d += 256) {
        q_sh[d] = q_batch[q_off + d];
        o_sh[d] = 0.0f;
    }
    if (tid == 0) {
        stats[0] = -INFINITY;
        stats[1] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = fast::rsqrt(float(head_dim));
    uint v_stride = num_kv_heads * head_dim;
    uint v_base = kv_head * head_dim;

    // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
    // Runs at least once because seq_len >= 1, so stats holds a real tile's
    // statistics by the time the output is normalized.
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);

        // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
        // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
        for (uint ti = simd_id; ti < tile_n; ti += 8) {
            uint k_off = (kv_base + ti) * v_stride + v_base;  // same layout as V stride
            float dot = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q_sh[d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores_sh[ti] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Tile max via cooperative reduction.
        // Simdgroups with no positions contribute -INFINITY.
        float local_max = -INFINITY;
        for (uint s = tid; s < tile_n; s += 256) {
            local_max = max(local_max, scores_sh[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // [3] Merge with running max, compute alpha, rescale running l.
        // Only thread 0 computes these; every thread receives them via
        // sg_scratch after the barrier below.
        float m_new;
        float alpha;
        if (tid == 0) {
            float tile_max = sg_scratch[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
            float m_old = stats[0];
            m_new = max(m_old, tile_max);
            // First iteration: m_old = -inf → alpha = 0 (reset O).
            alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            stats[0] = m_new;
            stats[1] *= alpha;
            // Broadcast via sg_scratch.
            sg_scratch[0] = alpha;
            sg_scratch[1] = m_new;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        alpha = sg_scratch[0];
        m_new = sg_scratch[1];

        // [4] Rescale running output by alpha, then compute exp(scores - m_new).
        for (uint d = tid; d < head_dim; d += 256) {
            o_sh[d] *= alpha;
        }
        for (uint s = tid; s < tile_n; s += 256) {
            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
        }
        // Also guarantees every thread has read alpha/m_new before step [5]
        // overwrites sg_scratch.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Tile sum → update running l.
        float local_sum = 0.0f;
        for (uint s = tid; s < tile_n; s += 256) {
            local_sum += scores_sh[s];
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
            stats[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = 0.0f;
            for (uint s = 0; s < tile_n; s++) {
                acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
            }
            o_sh[d] += acc;
        }
        // scores_sh and sg_scratch are rewritten by the next tile.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // --- Normalize and write output ---
    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        output_batch[out_off + d] = o_sh[d] * inv_l;
    }
}
2587
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Rotary position embedding applied in place to the Q and K heads of every
// token with a single dispatch (one launch + barrier per layer instead of two).
// Within each qkv_data row the num_q_heads Q heads start at float 0 and the
// num_kv_heads K heads follow immediately after them, so both can be
// addressed by one head-slot index.  Each thread rotates one adjacent
// (even, odd) pair i of a head by angle pos * theta^(-2i/head_dim),
// where pos = base_pos + token.
// NOTE(review): cos/sin use the library defaults; if the library is built
// with fast math, their accuracy for large angles (long contexts) may be
// reduced — confirm the compile options.
kernel void rope_qk_batch(
    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
    constant uint& M                 [[buffer(1)]],   // num tokens
    constant uint& base_pos          [[buffer(2)]],   // starting position
    constant uint& num_q_heads       [[buffer(3)]],
    constant uint& num_kv_heads      [[buffer(4)]],
    constant uint& head_dim          [[buffer(5)]],
    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
    constant float& theta            [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    uint half_dim = head_dim / 2;
    uint pairs_per_token = (num_q_heads + num_kv_heads) * half_dim;

    uint token = id / pairs_per_token;
    if (token >= M) return;

    // Slots [0, num_q_heads) are Q heads; the remaining slots are K heads.
    uint pair_id = id % pairs_per_token;
    uint slot    = pair_id / half_dim;
    uint i       = pair_id % half_dim;
    uint offset  = token * qkv_stride + slot * head_dim + i * 2;

    uint pos = base_pos + token;
    float inv_freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    float angle = float(pos) * inv_freq;
    float c = cos(angle);
    float s = sin(angle);

    float even = qkv_data[offset];
    float odd  = qkv_data[offset + 1];
    qkv_data[offset]     = even * c - odd * s;
    qkv_data[offset + 1] = even * s + odd * c;
}
2638
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Scatter both K and V from the strided batch QKV buffer into their KV
// caches with one dispatch (one launch + barrier per layer instead of two).
// The grid is 2 * M * kv_dim threads: the first half copies K, the second
// half copies V; each thread moves one float.
kernel void copy_kv_both_batch(
    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
    constant uint& M           [[buffer(3)]],  // num tokens in batch
    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    uint per_cache = M * kv_dim;            // floats copied into each cache
    if (id >= per_cache * 2) return;

    bool into_v = id >= per_cache;
    uint elem   = into_v ? id - per_cache : id;
    uint token  = elem / kv_dim;
    uint col    = elem - token * kv_dim;    // position within the KV vector

    uint from = token * src_stride + (into_v ? v_offset : k_offset) + col;
    uint to   = (base_pos + token) * kv_dim + col;

    device float* cache = into_v ? v_dst : k_dst;
    cache[to] = src[from];
}
2673"#
2674    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2675    .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2676}
2677
2678// ---------------------------------------------------------------------------
2679// model.rs generation
2680// ---------------------------------------------------------------------------
2681
2682fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2683    let mut code = String::with_capacity(48 * 1024);
2684    emit_model_header(&mut code, config)?;
2685    emit_metal_model_struct(&mut code, config)?;
2686    emit_layer_buffers_struct(&mut code, config)?;
2687    emit_metal_model_impl(&mut code, config)?;
2688    emit_helper_functions(&mut code)?;
2689    Ok(code)
2690}
2691
2692fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2693    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2694    writeln!(
2695        code,
2696        "//! Model: {} ({} layers, hidden={})",
2697        config.architecture, config.num_layers, config.hidden_size
2698    )?;
2699    writeln!(code, "//!")?;
2700    writeln!(
2701        code,
2702        "//! Uses native Metal compute pipelines via the metal crate."
2703    )?;
2704    writeln!(code)?;
2705    writeln!(code, "#![allow(dead_code)]")?;
2706    writeln!(code)?;
2707    writeln!(code, "use metal::*;")?;
2708    writeln!(code, "#[allow(unused_imports)]")?;
2709    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2710    writeln!(code, "use std::mem;")?;
2711    writeln!(code)?;
2712
2713    // Model constants
2714    writeln!(
2715        code,
2716        "// ── Model constants ──────────────────────────────────"
2717    )?;
2718    writeln!(
2719        code,
2720        "pub const HIDDEN_SIZE: usize = {};",
2721        config.hidden_size
2722    )?;
2723    writeln!(
2724        code,
2725        "pub const INTERMEDIATE_SIZE: usize = {};",
2726        config.intermediate_size
2727    )?;
2728    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2729    writeln!(
2730        code,
2731        "pub const NUM_HEADS: usize = {};",
2732        config.num_attention_heads
2733    )?;
2734    writeln!(
2735        code,
2736        "pub const NUM_KV_HEADS: usize = {};",
2737        config.num_kv_heads
2738    )?;
2739    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2740    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2741    let effective_seq_len = config.max_seq_len.min(4096);
2742    writeln!(
2743        code,
2744        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
2745        effective_seq_len, config.max_seq_len
2746    )?;
2747    writeln!(
2748        code,
2749        "pub const RMS_NORM_EPS: f32 = {:e};",
2750        config.rms_norm_eps
2751    )?;
2752    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
2753    writeln!(
2754        code,
2755        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
2756    )?;
2757    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
2758    writeln!(code)?;
2759
2760    Ok(())
2761}
2762
2763fn emit_metal_model_struct(
2764    code: &mut String,
2765    config: &ModelConfig,
2766) -> Result<(), MetalCodegenError> {
2767    writeln!(
2768        code,
2769        "// ── MetalModel ──────────────────────────────────────────"
2770    )?;
2771    writeln!(code)?;
2772    writeln!(
2773        code,
2774        "/// Metal-accelerated transformer model for Apple Silicon."
2775    )?;
2776    writeln!(code, "///")?;
2777    writeln!(
2778        code,
2779        "/// Uses unified memory for zero-copy weight access and native Metal"
2780    )?;
2781    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
2782    writeln!(code, "pub struct MetalModel {{")?;
2783    writeln!(code, "    device: Device,")?;
2784    writeln!(code, "    queue: CommandQueue,")?;
2785    writeln!(code)?;
2786    writeln!(code, "    // ── Compute pipelines ──")?;
2787    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
2788    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
2789    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
2790    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
2791    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
2792    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
2793    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
2794    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
2795    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
2796    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
2797    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
2798    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
2799    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
2800    writeln!(code)?;
2801    writeln!(code, "    // ── Batched prefill pipelines ──")?;
2802    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
2803    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
2804    writeln!(
2805        code,
2806        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
2807    )?;
2808    writeln!(code, "    matmul_q8_mma_pipeline: ComputePipelineState,")?;
2809    writeln!(code, "    matmul_q8_mma32_pipeline: ComputePipelineState,")?;
2810    writeln!(
2811        code,
2812        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
2813    )?;
2814    writeln!(
2815        code,
2816        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
2817    )?;
2818    writeln!(
2819        code,
2820        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
2821    )?;
2822    if config.qkv_bias {
2823        writeln!(code, "    add_bias_batch_pipeline: ComputePipelineState,")?;
2824    }
2825    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
2826    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
2827    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
2828    writeln!(
2829        code,
2830        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
2831    )?;
2832    writeln!(
2833        code,
2834        "    add_inplace_batch_pipeline: ComputePipelineState,"
2835    )?;
2836    writeln!(
2837        code,
2838        "    copy_embedding_batch_pipeline: ComputePipelineState,"
2839    )?;
2840    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
2841    writeln!(
2842        code,
2843        "    attention_flash_batch_pipeline: ComputePipelineState,"
2844    )?;
2845    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
2846    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
2847    writeln!(
2848        code,
2849        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
2850    )?;
2851    writeln!(code)?;
2852    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
2853    writeln!(
2854        code,
2855        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
2856    )?;
2857    writeln!(code, "    embed_buf: Buffer,")?;
2858    writeln!(code)?;
2859    writeln!(code, "    /// Per-layer weight buffers")?;
2860    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
2861    writeln!(code)?;
2862    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
2863    writeln!(code, "    norm_buf: Buffer,")?;
2864    writeln!(code)?;
2865    writeln!(
2866        code,
2867        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
2868    )?;
2869    writeln!(code, "    lm_head_buf: Buffer,")?;
2870    writeln!(code)?;
2871    writeln!(
2872        code,
2873        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
2874    )?;
2875    writeln!(code, "    hidden_buf: Buffer,")?;
2876    writeln!(code, "    residual_buf: Buffer,")?;
2877    writeln!(code, "    normed_buf: Buffer,")?;
2878    writeln!(
2879        code,
2880        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
2881    )?;
2882    writeln!(code, "    qkv_buf: Buffer,")?;
2883    writeln!(code, "    attn_out_buf: Buffer,")?;
2884    writeln!(code, "    attn_proj_buf: Buffer,")?;
2885    writeln!(
2886        code,
2887        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
2888    )?;
2889    writeln!(code, "    gate_up_buf: Buffer,")?;
2890    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
2891    writeln!(code, "    ffn_out_buf: Buffer,")?;
2892    writeln!(code, "    add_tmp_buf: Buffer,")?;
2893    writeln!(code, "    logits_buf: Buffer,")?;
2894    writeln!(code)?;
2895    writeln!(code, "    // ── Batched prefill working buffers ──")?;
2896    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
2897    writeln!(code, "    batch_hidden_buf: Buffer,")?;
2898    writeln!(
2899        code,
2900        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
2901    )?;
2902    writeln!(code, "    batch_residual_buf: Buffer,")?;
2903    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2904    writeln!(code, "    batch_qkv_buf: Buffer,")?;
2905    writeln!(
2906        code,
2907        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2908    )?;
2909    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
2910    writeln!(
2911        code,
2912        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2913    )?;
2914    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
2915    writeln!(
2916        code,
2917        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2918    )?;
2919    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
2920    writeln!(
2921        code,
2922        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2923    )?;
2924    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
2925    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2926    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
2927    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
2928    writeln!(code, "    batch_tokens_buf: Buffer,")?;
2929    writeln!(code, "    /// Positions buffer for batch RoPE")?;
2930    writeln!(code, "    batch_positions_buf: Buffer,")?;
2931    writeln!(code)?;
2932    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
2933    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
2934    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
2935    writeln!(code)?;
2936    writeln!(code, "    // ── Inference state ──")?;
2937    writeln!(code, "    pos: usize,")?;
2938    writeln!(code)?;
2939    writeln!(
2940        code,
2941        "    /// Previous command buffer for double-buffered prefill."
2942    )?;
2943    writeln!(
2944        code,
2945        "    /// While the GPU executes token N, the CPU can encode token N+1."
2946    )?;
2947    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
2948    writeln!(code, "}}")?;
2949    writeln!(code)?;
2950
2951    Ok(())
2952}
2953
2954fn emit_layer_buffers_struct(
2955    code: &mut String,
2956    config: &ModelConfig,
2957) -> Result<(), MetalCodegenError> {
2958    writeln!(
2959        code,
2960        "/// Per-layer weight buffers for attention and FFN projections."
2961    )?;
2962    writeln!(code, "struct LayerBuffers {{")?;
2963    writeln!(code, "    attn_norm: Buffer,")?;
2964    writeln!(
2965        code,
2966        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2967    )?;
2968    writeln!(code, "    qkv_weight: Buffer,")?;
2969    if config.qkv_bias {
2970        writeln!(
2971            code,
2972            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
2973        )?;
2974        writeln!(code, "    qkv_bias: Buffer,")?;
2975    }
2976    writeln!(code, "    o_weight: Buffer,")?;
2977    writeln!(code, "    ffn_norm: Buffer,")?;
2978    writeln!(
2979        code,
2980        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2981    )?;
2982    writeln!(code, "    gate_up_weight: Buffer,")?;
2983    writeln!(code, "    down_weight: Buffer,")?;
2984    writeln!(code, "}}")?;
2985    writeln!(code)?;
2986
2987    Ok(())
2988}
2989
2990fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2991    let hidden = config.hidden_size;
2992    let intermediate = config.intermediate_size;
2993    let _num_layers = config.num_layers;
2994    let num_heads = config.num_attention_heads;
2995    let num_kv_heads = config.num_kv_heads;
2996    let head_dim = config.head_dim;
2997    let vocab = config.vocab_size;
2998    let effective_seq_len = config.max_seq_len.min(4096);
2999    let is_q8 = config.dtype == DType::Q8_0;
3000    let is_q4 = config.dtype == DType::Q4_0;
3001    let kv_dim = num_kv_heads * head_dim;
3002
3003    writeln!(code, "impl MetalModel {{")?;
3004
3005    // ── new() ──
3006    writeln!(
3007        code,
3008        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3009    )?;
3010    writeln!(code, "    ///")?;
3011    writeln!(
3012        code,
3013        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3014    )?;
3015    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3016    writeln!(
3017        code,
3018        "        let device = Device::system_default().expect(\"no Metal device found\");"
3019    )?;
3020    writeln!(code, "        let queue = device.new_command_queue();")?;
3021    writeln!(code)?;
3022
3023    // Compile shaders
3024    writeln!(
3025        code,
3026        "        // Compile Metal shaders from embedded source"
3027    )?;
3028    writeln!(
3029        code,
3030        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3031    )?;
3032    writeln!(
3033        code,
3034        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3035    )?;
3036    writeln!(
3037        code,
3038        "            .expect(\"failed to compile Metal shaders\");"
3039    )?;
3040    writeln!(code)?;
3041
3042    // Create compute pipelines
3043    writeln!(code, "        // Create compute pipelines")?;
3044    for (var, fn_name) in [
3045        ("matmul_pipeline", "matmul_vec"),
3046        ("matmul_q8_pipeline", "matmul_vec_q8"),
3047        ("matmul_q4_pipeline", "matmul_vec_q4"),
3048        ("rms_norm_pipeline", "rms_norm"),
3049        ("rope_pipeline", "rope"),
3050        ("softmax_pipeline", "softmax"),
3051        ("silu_mul_pipeline", "silu_mul"),
3052        ("silu_mul_fused_pipeline", "silu_mul_fused"),
3053        ("add_pipeline", "elementwise_add"),
3054        ("attention_pipeline", "attention"),
3055        ("add_inplace_pipeline", "add_inplace"),
3056        ("copy_pipeline", "copy_buffer"),
3057        ("copy_offset_pipeline", "copy_offset"),
3058        ("matmul_batch_pipeline", "matmul_vec_batch"),
3059        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3060        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3061        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3062        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3063        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3064        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3065        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3066        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3067        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3068        ("rope_batch_pipeline", "rope_batch"),
3069        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3070        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3071        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3072        ("attention_batch_pipeline", "attention_batch"),
3073        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3074        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3075        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3076        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3077    ] {
3078        writeln!(
3079            code,
3080            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3081        )?;
3082    }
3083    if config.qkv_bias {
3084        writeln!(
3085            code,
3086            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3087        )?;
3088    }
3089    writeln!(code)?;
3090
3091    // Weight loading
3092    writeln!(
3093        code,
3094        "        // Load weights into Metal shared-memory buffers"
3095    )?;
3096    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3097    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3098    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3099    writeln!(code)?;
3100    writeln!(
3101        code,
3102        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3103    )?;
3104    writeln!(code)?;
3105    writeln!(
3106        code,
3107        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3108    )?;
3109    writeln!(
3110        code,
3111        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3112    )?;
3113    writeln!(code, "            let byte_len = n * f32_size;")?;
3114    writeln!(code, "            let cur = cursor.get();")?;
3115    writeln!(
3116        code,
3117        "            let data = &weights[cur..cur + byte_len];"
3118    )?;
3119    writeln!(code, "            cursor.set(cur + byte_len);")?;
3120    writeln!(code, "            device.new_buffer_with_data(")?;
3121    writeln!(code, "                data.as_ptr() as *const _,")?;
3122    writeln!(code, "                byte_len as u64,")?;
3123    writeln!(
3124        code,
3125        "                MTLResourceOptions::StorageModeShared,"
3126    )?;
3127    writeln!(code, "            )")?;
3128    writeln!(code, "        }};")?;
3129    writeln!(code)?;
3130
3131    if is_q8 {
3132        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3133        // We load them directly into Metal buffers without dequantizing,
3134        // and use the matmul_vec_q8 shader that operates on quantized data.
3135        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
3136        writeln!(
3137            code,
3138            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3139        )?;
3140        writeln!(
3141            code,
3142            "        // as raw bytes into a Metal buffer (no dequantization)."
3143        )?;
3144        writeln!(
3145            code,
3146            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3147        )?;
3148        writeln!(
3149            code,
3150            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3151        )?;
3152        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3153        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3154        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3155        writeln!(code, "            let cur = cursor.get();")?;
3156        writeln!(
3157            code,
3158            "            let data = &weights[cur..cur + total_raw];"
3159        )?;
3160        writeln!(code, "            cursor.set(cur + total_raw);")?;
3161        writeln!(code, "            device.new_buffer_with_data(")?;
3162        writeln!(code, "                data.as_ptr() as *const _,")?;
3163        writeln!(code, "                total_raw as u64,")?;
3164        writeln!(
3165            code,
3166            "                MTLResourceOptions::StorageModeShared,"
3167        )?;
3168        writeln!(code, "            )")?;
3169        writeln!(code, "        }};")?;
3170        writeln!(code)?;
3171        writeln!(
3172            code,
3173            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3174        )?;
3175        writeln!(
3176            code,
3177            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3178        )?;
3179        writeln!(
3180            code,
3181            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3182        )?;
3183        writeln!(
3184            code,
3185            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3186        )?;
3187        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3188        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3189        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3190        writeln!(code, "            let cur = cursor.get();")?;
3191        writeln!(
3192            code,
3193            "            let data = &weights[cur..cur + total_raw];"
3194        )?;
3195        writeln!(code, "            cursor.set(cur + total_raw);")?;
3196        writeln!(code, "            device.new_buffer_with_data(")?;
3197        writeln!(code, "                data.as_ptr() as *const _,")?;
3198        writeln!(code, "                total_raw as u64,")?;
3199        writeln!(
3200            code,
3201            "                MTLResourceOptions::StorageModeShared,"
3202        )?;
3203        writeln!(code, "            )")?;
3204        writeln!(code, "        }};")?;
3205        writeln!(code)?;
3206    }
3207
3208    if is_q4 {
3209        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3210        // We load them directly into Metal buffers without dequantizing,
3211        // and use the matmul_vec_q4 shader that operates on quantized data.
3212        // This quarters GPU memory usage vs f32 dequantization.
3213        writeln!(
3214            code,
3215            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3216        )?;
3217        writeln!(
3218            code,
3219            "        // as raw bytes into a Metal buffer (no dequantization)."
3220        )?;
3221        writeln!(
3222            code,
3223            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3224        )?;
3225        writeln!(
3226            code,
3227            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3228        )?;
3229        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3230        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3231        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3232        writeln!(code, "            let cur = cursor.get();")?;
3233        writeln!(
3234            code,
3235            "            let data = &weights[cur..cur + total_raw];"
3236        )?;
3237        writeln!(code, "            cursor.set(cur + total_raw);")?;
3238        writeln!(code, "            device.new_buffer_with_data(")?;
3239        writeln!(code, "                data.as_ptr() as *const _,")?;
3240        writeln!(code, "                total_raw as u64,")?;
3241        writeln!(
3242            code,
3243            "                MTLResourceOptions::StorageModeShared,"
3244        )?;
3245        writeln!(code, "            )")?;
3246        writeln!(code, "        }};")?;
3247        writeln!(code)?;
3248        writeln!(
3249            code,
3250            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3251        )?;
3252        writeln!(
3253            code,
3254            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3255        )?;
3256        writeln!(
3257            code,
3258            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3259        )?;
3260        writeln!(
3261            code,
3262            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3263        )?;
3264        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3265        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3266        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3267        writeln!(code, "            let cur = cursor.get();")?;
3268        writeln!(
3269            code,
3270            "            let data = &weights[cur..cur + total_raw];"
3271        )?;
3272        writeln!(code, "            cursor.set(cur + total_raw);")?;
3273        writeln!(code, "            device.new_buffer_with_data(")?;
3274        writeln!(code, "                data.as_ptr() as *const _,")?;
3275        writeln!(code, "                total_raw as u64,")?;
3276        writeln!(
3277            code,
3278            "                MTLResourceOptions::StorageModeShared,"
3279        )?;
3280        writeln!(code, "            )")?;
3281        writeln!(code, "        }};")?;
3282        writeln!(code)?;
3283    }
3284
3285    writeln!(
3286        code,
3287        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3288    )?;
3289    writeln!(code)?;
3290
3291    // Per-layer weights
3292    writeln!(
3293        code,
3294        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3295    )?;
3296    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3297
3298    // attn_norm is always f32
3299    writeln!(
3300        code,
3301        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3302    )?;
3303
3304    let qkv_rows = hidden + 2 * kv_dim;
3305    if is_q8 {
3306        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3307        writeln!(
3308            code,
3309            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3310        )?;
3311        if config.qkv_bias {
3312            writeln!(
3313                code,
3314                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3315            )?;
3316            writeln!(
3317                code,
3318                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3319            )?;
3320        }
3321        writeln!(
3322            code,
3323            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3324        )?;
3325    } else if is_q4 {
3326        writeln!(
3327            code,
3328            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3329        )?;
3330        if config.qkv_bias {
3331            writeln!(
3332                code,
3333                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3334            )?;
3335        }
3336        writeln!(
3337            code,
3338            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3339        )?;
3340    } else {
3341        writeln!(
3342            code,
3343            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3344        )?;
3345        if config.qkv_bias {
3346            writeln!(
3347                code,
3348                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3349            )?;
3350        }
3351        writeln!(
3352            code,
3353            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3354        )?;
3355    }
3356
3357    // ffn_norm is always f32
3358    writeln!(
3359        code,
3360        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3361    )?;
3362
3363    let gate_up_rows = 2 * intermediate;
3364    if is_q8 {
3365        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3366        writeln!(
3367            code,
3368            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3369        )?;
3370        writeln!(
3371            code,
3372            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3373        )?;
3374    } else if is_q4 {
3375        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3376        writeln!(
3377            code,
3378            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3379        )?;
3380        writeln!(
3381            code,
3382            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3383        )?;
3384    } else {
3385        // Fused gate+up weight: read both as a single contiguous f32 buffer
3386        writeln!(
3387            code,
3388            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3389        )?;
3390        writeln!(
3391            code,
3392            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3393        )?;
3394    }
3395
3396    writeln!(code, "            layers.push(LayerBuffers {{")?;
3397    writeln!(code, "                attn_norm,")?;
3398    writeln!(code, "                qkv_weight,")?;
3399    if config.qkv_bias {
3400        writeln!(code, "                qkv_bias,")?;
3401    }
3402    writeln!(code, "                o_weight,")?;
3403    writeln!(code, "                ffn_norm,")?;
3404    writeln!(code, "                gate_up_weight,")?;
3405    writeln!(code, "                down_weight,")?;
3406    writeln!(code, "            }});")?;
3407    writeln!(code, "        }}")?;
3408    writeln!(code)?;
3409
3410    // final_norm is always f32
3411    writeln!(
3412        code,
3413        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3414    )?;
3415    writeln!(code)?;
3416
3417    // lm_head
3418    if is_q8 {
3419        writeln!(
3420            code,
3421            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3422        )?;
3423    } else if is_q4 {
3424        writeln!(
3425            code,
3426            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3427        )?;
3428    } else {
3429        writeln!(
3430            code,
3431            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3432        )?;
3433    }
3434    writeln!(code)?;
3435
3436    // Working buffers
3437    let hidden_bytes = hidden * 4;
3438    let _kv_dim_bytes = kv_dim * 4;
3439    let intermediate_bytes = intermediate * 4;
3440    let vocab_bytes = vocab * 4;
3441    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3442
3443    writeln!(
3444        code,
3445        "        // Allocate working buffers (shared memory for zero-copy)"
3446    )?;
3447    writeln!(
3448        code,
3449        "        let opts = MTLResourceOptions::StorageModeShared;"
3450    )?;
3451    writeln!(
3452        code,
3453        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3454    )?;
3455    writeln!(
3456        code,
3457        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3458    )?;
3459    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3460    writeln!(
3461        code,
3462        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3463    )?;
3464    writeln!(
3465        code,
3466        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3467    )?;
3468    writeln!(
3469        code,
3470        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3471    )?;
3472    writeln!(
3473        code,
3474        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3475    )?;
3476    writeln!(
3477        code,
3478        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3479    )?;
3480    let gate_up_buf_bytes = 2 * intermediate * 4;
3481    writeln!(
3482        code,
3483        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3484    )?;
3485    writeln!(
3486        code,
3487        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3488    )?;
3489    writeln!(
3490        code,
3491        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3492    )?;
3493    writeln!(
3494        code,
3495        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3496    )?;
3497    writeln!(
3498        code,
3499        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3500    )?;
3501    writeln!(
3502        code,
3503        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3504    )?;
3505    writeln!(code)?;
3506
3507    // Batch prefill working buffers
3508    let batch_hidden_bytes = hidden * 4; // per-token
3509    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3510    let batch_gate_up_bytes = 2 * intermediate * 4;
3511    let batch_intermediate_bytes = intermediate * 4;
3512    writeln!(
3513        code,
3514        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3515    )?;
3516    writeln!(
3517        code,
3518        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3519    )?;
3520    writeln!(
3521        code,
3522        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3523    )?;
3524    writeln!(
3525        code,
3526        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3527    )?;
3528    writeln!(
3529        code,
3530        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3531    )?;
3532    writeln!(
3533        code,
3534        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3535    )?;
3536    writeln!(
3537        code,
3538        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3539    )?;
3540    writeln!(
3541        code,
3542        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3543    )?;
3544    writeln!(
3545        code,
3546        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3547    )?;
3548    writeln!(
3549        code,
3550        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3551    )?;
3552    writeln!(
3553        code,
3554        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3555    )?;
3556    writeln!(code)?;
3557
3558    // KV cache buffers
3559    writeln!(code, "        // KV cache buffers (per-layer)")?;
3560    writeln!(
3561        code,
3562        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3563    )?;
3564    writeln!(
3565        code,
3566        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3567    )?;
3568    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3569    writeln!(
3570        code,
3571        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3572    )?;
3573    writeln!(
3574        code,
3575        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3576    )?;
3577    writeln!(code, "        }}")?;
3578    writeln!(code)?;
3579
3580    writeln!(code, "        Self {{")?;
3581    writeln!(code, "            device,")?;
3582    writeln!(code, "            queue,")?;
3583    writeln!(code, "            matmul_pipeline,")?;
3584    writeln!(code, "            matmul_q8_pipeline,")?;
3585    writeln!(code, "            matmul_q4_pipeline,")?;
3586    writeln!(code, "            rms_norm_pipeline,")?;
3587    writeln!(code, "            rope_pipeline,")?;
3588    writeln!(code, "            softmax_pipeline,")?;
3589    writeln!(code, "            silu_mul_pipeline,")?;
3590    writeln!(code, "            silu_mul_fused_pipeline,")?;
3591    writeln!(code, "            add_pipeline,")?;
3592    writeln!(code, "            attention_pipeline,")?;
3593    writeln!(code, "            add_inplace_pipeline,")?;
3594    writeln!(code, "            copy_pipeline,")?;
3595    writeln!(code, "            copy_offset_pipeline,")?;
3596    writeln!(code, "            matmul_batch_pipeline,")?;
3597    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3598    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3599    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3600    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3601    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3602    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3603    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3604    if config.qkv_bias {
3605        writeln!(code, "            add_bias_batch_pipeline,")?;
3606    }
3607    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3608    writeln!(code, "            rms_norm_batch_pipeline,")?;
3609    writeln!(code, "            rope_batch_pipeline,")?;
3610    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3611    writeln!(code, "            add_inplace_batch_pipeline,")?;
3612    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3613    writeln!(code, "            attention_batch_pipeline,")?;
3614    writeln!(code, "            attention_flash_batch_pipeline,")?;
3615    writeln!(code, "            copy_kv_batch_pipeline,")?;
3616    writeln!(code, "            rope_qk_batch_pipeline,")?;
3617    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3618    writeln!(code, "            embed_buf,")?;
3619    writeln!(code, "            layers,")?;
3620    writeln!(code, "            norm_buf,")?;
3621    writeln!(code, "            lm_head_buf,")?;
3622    writeln!(code, "            hidden_buf,")?;
3623    writeln!(code, "            residual_buf,")?;
3624    writeln!(code, "            normed_buf,")?;
3625    writeln!(code, "            qkv_buf,")?;
3626    writeln!(code, "            attn_out_buf,")?;
3627    writeln!(code, "            attn_proj_buf,")?;
3628    writeln!(code, "            gate_up_buf,")?;
3629    writeln!(code, "            ffn_hidden_buf,")?;
3630    writeln!(code, "            ffn_out_buf,")?;
3631    writeln!(code, "            add_tmp_buf,")?;
3632    writeln!(code, "            logits_buf,")?;
3633    writeln!(code, "            batch_hidden_buf,")?;
3634    writeln!(code, "            batch_residual_buf,")?;
3635    writeln!(code, "            batch_qkv_buf,")?;
3636    writeln!(code, "            batch_attn_out_buf,")?;
3637    writeln!(code, "            batch_attn_proj_buf,")?;
3638    writeln!(code, "            batch_gate_up_buf,")?;
3639    writeln!(code, "            batch_ffn_hidden_buf,")?;
3640    writeln!(code, "            batch_ffn_out_buf,")?;
3641    writeln!(code, "            batch_tokens_buf,")?;
3642    writeln!(code, "            batch_positions_buf,")?;
3643    writeln!(code, "            k_cache,")?;
3644    writeln!(code, "            v_cache,")?;
3645    writeln!(code, "            pos: 0,")?;
3646    writeln!(code, "            prev_cmd: None,")?;
3647    writeln!(code, "        }}")?;
3648    writeln!(code, "    }}")?;
3649    writeln!(code)?;
3650
3651    // ── forward() ──
3652    writeln!(
3653        code,
3654        "    /// Run the forward pass for a single token at the current position."
3655    )?;
3656    writeln!(code, "    ///")?;
3657    writeln!(
3658        code,
3659        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3660    )?;
3661    writeln!(code, "    ///")?;
3662    writeln!(
3663        code,
3664        "    /// All GPU operations are encoded into a single command buffer and"
3665    )?;
3666    writeln!(
3667        code,
3668        "    /// committed once at the end, avoiding per-operation synchronization."
3669    )?;
3670    writeln!(
3671        code,
3672        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3673    )?;
3674    writeln!(
3675        code,
3676        "        // Wait for any pending prefill command buffer"
3677    )?;
3678    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3679    writeln!(code, "            prev.wait_until_completed();")?;
3680    writeln!(code, "        }}")?;
3681    writeln!(code)?;
3682    writeln!(code, "        let pos = self.pos;")?;
3683    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3684    writeln!(code)?;
3685
3686    // Single compute encoder for the entire forward pass — no blit encoder
3687    // transitions. Copy operations use compute copy kernels instead of blits.
3688    let matmul_fn = if is_q8 {
3689        "dispatch_matmul_q8"
3690    } else if is_q4 {
3691        "dispatch_matmul_q4"
3692    } else {
3693        "dispatch_matmul"
3694    };
3695
3696    writeln!(
3697        code,
3698        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3699    )?;
3700    writeln!(code, "        {{")?;
3701    writeln!(
3702        code,
3703        "            let enc = cmd.new_compute_command_encoder();"
3704    )?;
3705    writeln!(code)?;
3706
3707    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3708    writeln!(
3709        code,
3710        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3711    )?;
3712    writeln!(
3713        code,
3714        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3715    )?;
3716    writeln!(
3717        code,
3718        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3719    )?;
3720    writeln!(
3721        code,
3722        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3723        hidden * 4,
3724    )?;
3725    writeln!(code, "            unsafe {{")?;
3726    writeln!(
3727        code,
3728        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
3729    )?;
3730    writeln!(
3731        code,
3732        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3733    )?;
3734    writeln!(
3735        code,
3736        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
3737    )?;
3738    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
3739    writeln!(
3740        code,
3741        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3742    )?;
3743    writeln!(code, "                    hidden_ptr,")?;
3744    writeln!(code, "                    HIDDEN_SIZE,")?;
3745    writeln!(code, "                );")?;
3746    writeln!(
3747        code,
3748        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3749    )?;
3750    writeln!(code, "            }}")?;
3751    writeln!(code)?;
3752
3753    // 2. Transformer layers
3754    writeln!(code, "            // 2. Transformer layers")?;
3755    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3756    writeln!(code)?;
3757    let q_byte_offset = 0usize;
3758    let k_byte_offset = hidden * 4;
3759    let v_byte_offset = (hidden + kv_dim) * 4;
3760
3761    writeln!(
3762        code,
3763        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
3764    )?;
3765    writeln!(
3766        code,
3767        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3768    )?;
3769    writeln!(
3770        code,
3771        "                // Fused Q+K+V matmul: single dispatch for all three projections"
3772    )?;
3773    writeln!(
3774        code,
3775        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3776    )?;
3777    if config.qkv_bias {
3778        writeln!(
3779            code,
3780            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
3781        )?;
3782        writeln!(
3783            code,
3784            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3785        )?;
3786    }
3787    writeln!(
3788        code,
3789        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
3790    )?;
3791    writeln!(
3792        code,
3793        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3794    )?;
3795    writeln!(
3796        code,
3797        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3798    )?;
3799    writeln!(code)?;
3800    writeln!(
3801        code,
3802        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
3803    )?;
3804    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3805    writeln!(
3806        code,
3807        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3808    )?;
3809    writeln!(
3810        code,
3811        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3812    )?;
3813    writeln!(code)?;
3814    writeln!(
3815        code,
3816        "                // Attention using Q from qkv_buf (offset 0)"
3817    )?;
3818    writeln!(
3819        code,
3820        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3821    )?;
3822    writeln!(
3823        code,
3824        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3825    )?;
3826    writeln!(
3827        code,
3828        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3829    )?;
3830    writeln!(
3831        code,
3832        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
3833    )?;
3834    writeln!(
3835        code,
3836        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3837    )?;
3838    writeln!(
3839        code,
3840        "                // Fused gate+up matmul: single dispatch for both projections"
3841    )?;
3842    writeln!(
3843        code,
3844        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3845    )?;
3846    writeln!(
3847        code,
3848        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3849    )?;
3850    writeln!(
3851        code,
3852        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3853    )?;
3854    writeln!(
3855        code,
3856        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3857    )?;
3858    writeln!(code, "            }}")?;
3859    writeln!(code)?;
3860
3861    // 3. Final RMS norm + logits
3862    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
3863    writeln!(
3864        code,
3865        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3866    )?;
3867    writeln!(
3868        code,
3869        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3870    )?;
3871    writeln!(code)?;
3872    writeln!(code, "            enc.end_encoding();")?;
3873    writeln!(code, "        }}")?;
3874    writeln!(code)?;
3875
3876    // 5. Single commit + wait, then read back logits
3877    writeln!(
3878        code,
3879        "        // 5. Commit all GPU work and wait for completion"
3880    )?;
3881    writeln!(code, "        cmd.commit();")?;
3882    writeln!(code, "        cmd.wait_until_completed();")?;
3883    writeln!(code)?;
3884    writeln!(code, "        // 6. Read back logits from GPU")?;
3885    writeln!(code, "        let logits = unsafe {{")?;
3886    writeln!(
3887        code,
3888        "            let ptr = self.logits_buf.contents() as *const f32;"
3889    )?;
3890    writeln!(
3891        code,
3892        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3893    )?;
3894    writeln!(code, "        }};")?;
3895    writeln!(code)?;
3896    writeln!(code, "        self.pos += 1;")?;
3897    writeln!(code, "        logits")?;
3898    writeln!(code, "    }}")?;
3899    writeln!(code)?;
3900
3901    // ── forward_profile: instrumented forward with per-operation timing ──
3902    writeln!(
3903        code,
3904        "    /// Profiling forward pass that prints per-stage GPU timing."
3905    )?;
3906    writeln!(code, "    ///")?;
3907    writeln!(
3908        code,
3909        "    /// Each stage is committed and waited on separately so that GPU timestamps"
3910    )?;
3911    writeln!(
3912        code,
3913        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
3914    )?;
3915    writeln!(
3916        code,
3917        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
3918    )?;
3919    writeln!(
3920        code,
3921        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
3922    )?;
3923    writeln!(code, "        use std::time::Instant;")?;
3924    writeln!(code)?;
3925    writeln!(
3926        code,
3927        "        // Wait for any pending prefill command buffer"
3928    )?;
3929    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3930    writeln!(code, "            prev.wait_until_completed();")?;
3931    writeln!(code, "        }}")?;
3932    writeln!(code)?;
3933    writeln!(code, "        let pos = self.pos;")?;
3934    writeln!(code)?;
3935
3936    // Stage: embedding (CPU, no GPU)
3937    writeln!(
3938        code,
3939        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
3940    )?;
3941    writeln!(code, "        let t_embed = Instant::now();")?;
3942    writeln!(code, "        unsafe {{")?;
3943    writeln!(
3944        code,
3945        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
3946    )?;
3947    writeln!(
3948        code,
3949        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3950    )?;
3951    writeln!(
3952        code,
3953        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
3954    )?;
3955    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
3956    writeln!(
3957        code,
3958        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3959    )?;
3960    writeln!(code, "                hidden_ptr,")?;
3961    writeln!(code, "                HIDDEN_SIZE,")?;
3962    writeln!(code, "            );")?;
3963    writeln!(
3964        code,
3965        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3966    )?;
3967    writeln!(code, "        }}")?;
3968    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
3969    writeln!(code)?;
3970
3971    // Stage: Transformer layers (all together on GPU)
3972    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
3973    writeln!(code, "        let t_layers = Instant::now();")?;
3974    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3975    writeln!(code, "        {{")?;
3976    writeln!(
3977        code,
3978        "            let enc = cmd.new_compute_command_encoder();"
3979    )?;
3980    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3981    writeln!(
3982        code,
3983        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3984    )?;
3985    writeln!(
3986        code,
3987        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3988    )?;
3989    if config.qkv_bias {
3990        writeln!(
3991            code,
3992            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3993        )?;
3994    }
3995    writeln!(
3996        code,
3997        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3998    )?;
3999    writeln!(
4000        code,
4001        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4002    )?;
4003    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
4004    writeln!(
4005        code,
4006        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4007    )?;
4008    writeln!(
4009        code,
4010        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4011    )?;
4012    writeln!(
4013        code,
4014        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4015    )?;
4016    writeln!(
4017        code,
4018        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4019    )?;
4020    writeln!(
4021        code,
4022        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4023    )?;
4024    writeln!(
4025        code,
4026        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4027    )?;
4028    writeln!(
4029        code,
4030        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4031    )?;
4032    writeln!(
4033        code,
4034        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4035    )?;
4036    writeln!(
4037        code,
4038        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4039    )?;
4040    writeln!(
4041        code,
4042        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4043    )?;
4044    writeln!(code, "            }}")?;
4045    writeln!(code, "            enc.end_encoding();")?;
4046    writeln!(code, "        }}")?;
4047    writeln!(code, "        cmd.commit();")?;
4048    writeln!(code, "        cmd.wait_until_completed();")?;
4049    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4050    writeln!(code)?;
4051
4052    // Stage: Final norm + logits
4053    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4054    writeln!(code, "        let t_logits = Instant::now();")?;
4055    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4056    writeln!(code, "        {{")?;
4057    writeln!(
4058        code,
4059        "            let enc = cmd2.new_compute_command_encoder();"
4060    )?;
4061    writeln!(
4062        code,
4063        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4064    )?;
4065    writeln!(
4066        code,
4067        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4068    )?;
4069    writeln!(code, "            enc.end_encoding();")?;
4070    writeln!(code, "        }}")?;
4071    writeln!(code, "        cmd2.commit();")?;
4072    writeln!(code, "        cmd2.wait_until_completed();")?;
4073    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4074    writeln!(code)?;
4075
4076    // Print profile results
4077    writeln!(
4078        code,
4079        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4080    )?;
4081    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4082    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4083    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4084    writeln!(
4085        code,
4086        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4087    )?;
4088    writeln!(code)?;
4089
4090    // Read back logits
4091    writeln!(code, "        let logits = unsafe {{")?;
4092    writeln!(
4093        code,
4094        "            let ptr = self.logits_buf.contents() as *const f32;"
4095    )?;
4096    writeln!(
4097        code,
4098        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4099    )?;
4100    writeln!(code, "        }};")?;
4101    writeln!(code)?;
4102    writeln!(code, "        self.pos += 1;")?;
4103    writeln!(code, "        logits")?;
4104    writeln!(code, "    }}")?;
4105    writeln!(code)?;
4106
4107    // ── forward_prefill: single-token async forward (backward compat) ──
4108    writeln!(
4109        code,
4110        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4111    )?;
4112    writeln!(code, "    ///")?;
4113    writeln!(
4114        code,
4115        "    /// Commits the command buffer without waiting, enabling double-buffered"
4116    )?;
4117    writeln!(
4118        code,
4119        "    /// execution: GPU processes token N while CPU encodes token N+1."
4120    )?;
4121    writeln!(
4122        code,
4123        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4124    )?;
4125    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4126    writeln!(code, "    }}")?;
4127    writeln!(code)?;
4128
4129    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4130    // Batched matmuls for QKV/O/FFN projections, sequential attention (causal dependency).
4131    let batch_matmul_fn = if is_q8 {
4132        "dispatch_matmul_q8_batch"
4133    } else if is_q4 {
4134        "dispatch_matmul_q4_batch"
4135    } else {
4136        "dispatch_matmul_batch"
4137    };
4138
4139    writeln!(
4140        code,
4141        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4142    )?;
4143    writeln!(code, "    ///")?;
4144    writeln!(
4145        code,
4146        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4147    )?;
4148    writeln!(
4149        code,
4150        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4151    )?;
4152    writeln!(
4153        code,
4154        "    /// This provides significant speedup during prompt prefill."
4155    )?;
4156    writeln!(
4157        code,
4158        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4159    )?;
4160    writeln!(code, "        if tokens.is_empty() {{ return; }}")?;
4161    writeln!(
4162        code,
4163        "        // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4164    )?;
4165    writeln!(
4166        code,
4167        "        // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4168    )?;
4169    writeln!(
4170        code,
4171        "        // longer than that must be processed iteratively.  The KV cache"
4172    )?;
4173    writeln!(code, "        // carries state across chunks via self.pos.")?;
4174    writeln!(
4175        code,
4176        "        for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4177    )?;
4178    writeln!(code, "        let m = chunk.len();")?;
4179    writeln!(code, "        if m == 0 {{ continue; }}")?;
4180    writeln!(code, "        let start_pos = self.pos;")?;
4181    writeln!(code)?;
4182    writeln!(code, "        // Wait for any pending command buffer")?;
4183    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4184    writeln!(code, "            prev.wait_until_completed();")?;
4185    writeln!(code, "        }}")?;
4186    writeln!(code)?;
4187
4188    // Upload token IDs and positions to GPU
4189    writeln!(
4190        code,
4191        "        // Upload token IDs and positions to GPU buffers"
4192    )?;
4193    writeln!(code, "        unsafe {{")?;
4194    writeln!(
4195        code,
4196        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4197    )?;
4198    writeln!(
4199        code,
4200        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4201    )?;
4202    writeln!(code, "            for i in 0..m {{")?;
4203    writeln!(code, "                *tok_ptr.add(i) = chunk[i];")?;
4204    writeln!(
4205        code,
4206        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4207    )?;
4208    writeln!(code, "            }}")?;
4209    writeln!(code, "        }}")?;
4210    writeln!(code)?;
4211
4212    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4213    writeln!(code, "        {{")?;
4214    writeln!(
4215        code,
4216        "            let enc = cmd.new_compute_command_encoder();"
4217    )?;
4218    writeln!(code)?;
4219
4220    // 1. Batch embedding lookup
4221    writeln!(
4222        code,
4223        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4224    )?;
4225    writeln!(
4226        code,
4227        "            self.dispatch_copy_embedding_batch(&enc, m);"
4228    )?;
4229    // Copy batch_hidden -> batch_residual
4230    writeln!(
4231        code,
4232        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4233    )?;
4234    writeln!(code)?;
4235
4236    // 2. Transformer layers
4237    writeln!(code, "            // 2. Transformer layers")?;
4238    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4239    writeln!(code)?;
4240
4241    // Batch RMS norm: residual -> hidden (batched)
4242    writeln!(
4243        code,
4244        "                // Batch RMS norm: batch_residual -> batch_hidden"
4245    )?;
4246    writeln!(
4247        code,
4248        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4249    )?;
4250
4251    // Batch QKV matmul
4252    writeln!(
4253        code,
4254        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4255    )?;
4256    writeln!(
4257        code,
4258        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4259    )?;
4260    if config.qkv_bias {
4261        writeln!(
4262            code,
4263            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4264        )?;
4265        writeln!(
4266            code,
4267            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4268        )?;
4269    }
4270    writeln!(code)?;
4271
4272    // Fused RoPE on Q+K portions in a single dispatch
4273    let k_float_offset = hidden;
4274    writeln!(
4275        code,
4276        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4277    )?;
4278    writeln!(
4279        code,
4280        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4281    )?;
4282    writeln!(code)?;
4283
4284    // Fused KV cache update: copy both K and V in a single dispatch
4285    let v_float_offset = hidden + kv_dim;
4286    writeln!(
4287        code,
4288        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4289    )?;
4290    writeln!(
4291        code,
4292        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4293    )?;
4294    writeln!(code)?;
4295
4296    // Batched causal attention: ONE dispatch for all M tokens
4297    writeln!(
4298        code,
4299        "                // Batched causal attention: one dispatch for all M tokens"
4300    )?;
4301    writeln!(
4302        code,
4303        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4304    )?;
4305    writeln!(code)?;
4306
4307    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4308    writeln!(code, "                // Batched O projection")?;
4309    writeln!(
4310        code,
4311        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4312    )?;
4313    writeln!(code)?;
4314
4315    // Batch add: residual += attn_proj for all tokens
4316    writeln!(
4317        code,
4318        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4319    )?;
4320    writeln!(code)?;
4321
4322    // Batch FFN
4323    writeln!(
4324        code,
4325        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4326    )?;
4327    writeln!(
4328        code,
4329        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4330    )?;
4331    writeln!(
4332        code,
4333        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4334    )?;
4335    writeln!(
4336        code,
4337        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4338    )?;
4339    writeln!(
4340        code,
4341        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4342    )?;
4343    writeln!(
4344        code,
4345        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4346    )?;
4347    writeln!(code, "            }}")?;
4348    writeln!(code)?;
4349
4350    // Copy last token's residual to single-token residual_buf for next forward() call
4351    writeln!(
4352        code,
4353        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4354    )?;
4355    writeln!(
4356        code,
4357        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4358    )?;
4359    writeln!(code)?;
4360    writeln!(code, "            enc.end_encoding();")?;
4361    writeln!(code, "        }}")?;
4362    writeln!(code)?;
4363
4364    writeln!(code, "        cmd.commit();")?;
4365    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4366    writeln!(code, "        self.pos += m;")?;
4367    writeln!(code, "        }}  // end for chunk")?;
4368    writeln!(code, "    }}")?;
4369    writeln!(code)?;
4370
4371    // ── reset() — rewind KV cache position for new inference requests ──
4372    writeln!(
4373        code,
4374        "    /// Reset the model state for a new inference request."
4375    )?;
4376    writeln!(code, "    pub fn reset(&mut self) {{")?;
4377    writeln!(code, "        self.pos = 0;")?;
4378    writeln!(code, "        self.prev_cmd = None;")?;
4379    writeln!(code, "    }}")?;
4380    writeln!(code)?;
4381
4382    // ── Private dispatch helpers (all take a shared compute encoder) ──
4383    writeln!(
4384        code,
4385        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4386    )?;
4387    writeln!(
4388        code,
4389        "    // These methods set pipeline state + buffers + dispatch on an existing"
4390    )?;
4391    writeln!(
4392        code,
4393        "    // encoder, avoiding per-operation encoder creation overhead."
4394    )?;
4395    writeln!(code)?;
4396
4397    // dispatch_rms_norm
4398    writeln!(
4399        code,
4400        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4401    )?;
4402    writeln!(
4403        code,
4404        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4405    )?;
4406    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4407    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4408    writeln!(
4409        code,
4410        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4411    )?;
4412    writeln!(
4413        code,
4414        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4415    )?;
4416    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4417    writeln!(
4418        code,
4419        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4420    )?;
4421    writeln!(
4422        code,
4423        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4424    )?;
4425    writeln!(
4426        code,
4427        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4428    )?;
4429    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4430    writeln!(
4431        code,
4432        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4433    )?;
4434    writeln!(
4435        code,
4436        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4437    )?;
4438    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4439    writeln!(code, "    }}")?;
4440    writeln!(code)?;
4441
4442    // dispatch_matmul
4443    writeln!(
4444        code,
4445        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4446    )?;
4447    writeln!(
4448        code,
4449        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4450    )?;
4451    writeln!(code, "        let r: u32 = rows as u32;")?;
4452    writeln!(code, "        let c: u32 = cols as u32;")?;
4453    writeln!(
4454        code,
4455        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4456    )?;
4457    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4458    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4459    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4460    writeln!(
4461        code,
4462        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4463    )?;
4464    writeln!(
4465        code,
4466        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4467    )?;
4468    writeln!(
4469        code,
4470        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4471    )?;
4472    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4473    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4474    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4475    writeln!(
4476        code,
4477        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4478    )?;
4479    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4480    writeln!(code, "    }}")?;
4481    writeln!(code)?;
4482
4483    // dispatch_matmul_q8
4484    writeln!(
4485        code,
4486        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4487    )?;
4488    writeln!(
4489        code,
4490        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4491    )?;
4492    writeln!(
4493        code,
4494        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4495    )?;
4496    writeln!(code, "        let r: u32 = rows as u32;")?;
4497    writeln!(code, "        let c: u32 = cols as u32;")?;
4498    writeln!(
4499        code,
4500        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4501    )?;
4502    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4503    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4504    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4505    writeln!(
4506        code,
4507        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4508    )?;
4509    writeln!(
4510        code,
4511        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4512    )?;
4513    writeln!(
4514        code,
4515        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4516    )?;
4517    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4518    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4519    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4520    writeln!(
4521        code,
4522        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4523    )?;
4524    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4525    writeln!(code, "    }}")?;
4526    writeln!(code)?;
4527
4528    // dispatch_matmul_q4
4529    writeln!(
4530        code,
4531        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4532    )?;
4533    writeln!(
4534        code,
4535        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4536    )?;
4537    writeln!(
4538        code,
4539        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4540    )?;
4541    writeln!(code, "        let r: u32 = rows as u32;")?;
4542    writeln!(code, "        let c: u32 = cols as u32;")?;
4543    writeln!(
4544        code,
4545        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4546    )?;
4547    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4548    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4549    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4550    writeln!(
4551        code,
4552        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4553    )?;
4554    writeln!(
4555        code,
4556        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4557    )?;
4558    writeln!(
4559        code,
4560        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4561    )?;
4562    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4563    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4564    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4565    writeln!(
4566        code,
4567        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4568    )?;
4569    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4570    writeln!(code, "    }}")?;
4571    writeln!(code)?;
4572
4573    // dispatch_rope
4574    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4575    writeln!(
4576        code,
4577        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4578    )?;
4579    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4580    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4581    writeln!(code, "        let p: u32 = pos as u32;")?;
4582    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4583    writeln!(
4584        code,
4585        "        let total_pairs = num_heads * (head_dim / 2);"
4586    )?;
4587    writeln!(
4588        code,
4589        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4590    )?;
4591    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4592    writeln!(
4593        code,
4594        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4595    )?;
4596    writeln!(
4597        code,
4598        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4599    )?;
4600    writeln!(
4601        code,
4602        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4603    )?;
4604    writeln!(
4605        code,
4606        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4607    )?;
4608    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4609    writeln!(
4610        code,
4611        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4612    )?;
4613    writeln!(
4614        code,
4615        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4616    )?;
4617    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4618    writeln!(code, "    }}")?;
4619    writeln!(code)?;
4620
4621    // dispatch_rope_offset
4622    writeln!(
4623        code,
4624        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4625    )?;
4626    writeln!(
4627        code,
4628        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4629    )?;
4630    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4631    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4632    writeln!(code, "        let p: u32 = pos as u32;")?;
4633    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4634    writeln!(
4635        code,
4636        "        let total_pairs = num_heads * (head_dim / 2);"
4637    )?;
4638    writeln!(
4639        code,
4640        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4641    )?;
4642    writeln!(
4643        code,
4644        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4645    )?;
4646    writeln!(
4647        code,
4648        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4649    )?;
4650    writeln!(
4651        code,
4652        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4653    )?;
4654    writeln!(
4655        code,
4656        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4657    )?;
4658    writeln!(
4659        code,
4660        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4661    )?;
4662    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4663    writeln!(
4664        code,
4665        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4666    )?;
4667    writeln!(
4668        code,
4669        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4670    )?;
4671    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4672    writeln!(code, "    }}")?;
4673    writeln!(code)?;
4674
4675    // dispatch_attention
4676    writeln!(
4677        code,
4678        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4679    )?;
4680    writeln!(
4681        code,
4682        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4683    )?;
4684    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4685    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4686    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4687    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4688    writeln!(
4689        code,
4690        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4691    )?;
4692    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4693    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4694    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4695    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4696    writeln!(
4697        code,
4698        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4699    )?;
4700    writeln!(
4701        code,
4702        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4703    )?;
4704    writeln!(
4705        code,
4706        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4707    )?;
4708    writeln!(
4709        code,
4710        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4711    )?;
4712    writeln!(
4713        code,
4714        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4715    )?;
4716    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4717    writeln!(
4718        code,
4719        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4720    )?;
4721    writeln!(
4722        code,
4723        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4724    )?;
4725    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4726    writeln!(code, "    }}")?;
4727    writeln!(code)?;
4728
4729    // dispatch_attention_offset
4730    writeln!(
4731        code,
4732        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4733    )?;
4734    writeln!(
4735        code,
4736        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4737    )?;
4738    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4739    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4740    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4741    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4742    writeln!(
4743        code,
4744        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4745    )?;
4746    writeln!(
4747        code,
4748        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
4749    )?;
4750    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4751    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4752    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4753    writeln!(
4754        code,
4755        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4756    )?;
4757    writeln!(
4758        code,
4759        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4760    )?;
4761    writeln!(
4762        code,
4763        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4764    )?;
4765    writeln!(
4766        code,
4767        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4768    )?;
4769    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4770    writeln!(
4771        code,
4772        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4773    )?;
4774    writeln!(
4775        code,
4776        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4777    )?;
4778    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4779    writeln!(code, "    }}")?;
4780    writeln!(code)?;
4781
4782    // dispatch_silu_mul
4783    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
4784    writeln!(
4785        code,
4786        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
4787    )?;
4788    writeln!(code, "        let count: u32 = n as u32;")?;
4789    writeln!(
4790        code,
4791        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
4792    )?;
4793    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
4794    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
4795    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4796    writeln!(
4797        code,
4798        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4799    )?;
4800    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4801    writeln!(
4802        code,
4803        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4804    )?;
4805    writeln!(
4806        code,
4807        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4808    )?;
4809    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4810    writeln!(code, "    }}")?;
4811    writeln!(code)?;
4812
4813    // dispatch_silu_mul_fused
4814    writeln!(
4815        code,
4816        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
4817    )?;
4818    writeln!(
4819        code,
4820        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
4821    )?;
4822    writeln!(
4823        code,
4824        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
4825    )?;
4826    writeln!(code, "        let count: u32 = n as u32;")?;
4827    writeln!(
4828        code,
4829        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
4830    )?;
4831    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
4832    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
4833    writeln!(
4834        code,
4835        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4836    )?;
4837    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4838    writeln!(
4839        code,
4840        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4841    )?;
4842    writeln!(
4843        code,
4844        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4845    )?;
4846    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4847    writeln!(code, "    }}")?;
4848    writeln!(code)?;
4849
4850    // dispatch_copy (simple src -> dst copy via compute kernel)
4851    writeln!(
4852        code,
4853        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
4854    )?;
4855    writeln!(
4856        code,
4857        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4858    )?;
4859    writeln!(code, "        let n: u32 = count as u32;")?;
4860    writeln!(
4861        code,
4862        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4863    )?;
4864    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4865    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4866    writeln!(
4867        code,
4868        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4869    )?;
4870    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4871    writeln!(
4872        code,
4873        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4874    )?;
4875    writeln!(
4876        code,
4877        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4878    )?;
4879    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4880    writeln!(code, "    }}")?;
4881    writeln!(code)?;
4882
4883    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
4884    writeln!(
4885        code,
4886        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
4887    )?;
4888    writeln!(
4889        code,
4890        "    /// Used for embedding table lookup (copy a specific row)."
4891    )?;
4892    writeln!(
4893        code,
4894        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
4895    )?;
4896    writeln!(code, "        let off: u32 = src_offset as u32;")?;
4897    writeln!(code, "        let n: u32 = count as u32;")?;
4898    writeln!(
4899        code,
4900        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
4901    )?;
4902    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4903    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4904    writeln!(
4905        code,
4906        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
4907    )?;
4908    writeln!(
4909        code,
4910        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4911    )?;
4912    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4913    writeln!(
4914        code,
4915        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4916    )?;
4917    writeln!(
4918        code,
4919        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4920    )?;
4921    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4922    writeln!(code, "    }}")?;
4923    writeln!(code)?;
4924
4925    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
4926    writeln!(
4927        code,
4928        "    /// Dispatch copy from source at byte offset to destination at float offset."
4929    )?;
4930    writeln!(
4931        code,
4932        "    /// Used for KV cache updates from fused QKV buffer."
4933    )?;
4934    writeln!(
4935        code,
4936        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4937    )?;
4938    writeln!(code, "        let n: u32 = count as u32;")?;
4939    writeln!(
4940        code,
4941        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4942    )?;
4943    writeln!(
4944        code,
4945        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4946    )?;
4947    writeln!(
4948        code,
4949        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4950    )?;
4951    writeln!(
4952        code,
4953        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4954    )?;
4955    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4956    writeln!(
4957        code,
4958        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4959    )?;
4960    writeln!(
4961        code,
4962        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4963    )?;
4964    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4965    writeln!(code, "    }}")?;
4966    writeln!(code)?;
4967
4968    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
4969    writeln!(
4970        code,
4971        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
4972    )?;
4973    writeln!(
4974        code,
4975        "    /// Used for KV cache updates (write to a specific position in the cache)."
4976    )?;
4977    writeln!(
4978        code,
4979        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
4980    )?;
4981    writeln!(code, "        let n: u32 = count as u32;")?;
4982    writeln!(
4983        code,
4984        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4985    )?;
4986    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4987    writeln!(
4988        code,
4989        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
4990    )?;
4991    writeln!(
4992        code,
4993        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4994    )?;
4995    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4996    writeln!(
4997        code,
4998        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4999    )?;
5000    writeln!(
5001        code,
5002        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5003    )?;
5004    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5005    writeln!(code, "    }}")?;
5006    writeln!(code)?;
5007
5008    // dispatch_add_inplace (residual connection, no blit needed)
5009    writeln!(
5010        code,
5011        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5012    )?;
5013    writeln!(
5014        code,
5015        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5016    )?;
5017    writeln!(code, "        let count: u32 = n as u32;")?;
5018    writeln!(
5019        code,
5020        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5021    )?;
5022    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5023    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5024    writeln!(
5025        code,
5026        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5027    )?;
5028    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5029    writeln!(
5030        code,
5031        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5032    )?;
5033    writeln!(
5034        code,
5035        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5036    )?;
5037    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5038    writeln!(code, "    }}")?;
5039    writeln!(code)?;
5040
5041    // ── Batched prefill dispatch helpers ──
5042    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
5043    writeln!(code)?;
5044
5045    // dispatch_copy_embedding_batch
5046    writeln!(
5047        code,
5048        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
5049    )?;
5050    writeln!(
5051        code,
5052        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5053    )?;
5054    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
5055    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5056    writeln!(
5057        code,
5058        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5059    )?;
5060    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5061    writeln!(
5062        code,
5063        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5064    )?;
5065    writeln!(
5066        code,
5067        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5068    )?;
5069    writeln!(
5070        code,
5071        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5072    )?;
5073    writeln!(
5074        code,
5075        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5076    )?;
5077    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
5078    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5079    writeln!(
5080        code,
5081        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5082    )?;
5083    writeln!(
5084        code,
5085        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5086    )?;
5087    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5088    writeln!(code, "    }}")?;
5089    writeln!(code)?;
5090
5091    // dispatch_rms_norm_batch
5092    writeln!(
5093        code,
5094        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
5095    )?;
5096    writeln!(
5097        code,
5098        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5099    )?;
5100    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
5101    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
5102    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5103    writeln!(
5104        code,
5105        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5106    )?;
5107    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
5108    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
5109    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5110    writeln!(
5111        code,
5112        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5113    )?;
5114    writeln!(
5115        code,
5116        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5117    )?;
5118    writeln!(
5119        code,
5120        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5121    )?;
5122    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5123    writeln!(
5124        code,
5125        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5126    )?;
5127    writeln!(
5128        code,
5129        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5130    )?;
5131    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5132    writeln!(code, "    }}")?;
5133    writeln!(code)?;
5134
5135    // dispatch_matmul_batch (f32)
5136    writeln!(
5137        code,
5138        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5139    )?;
5140    writeln!(
5141        code,
5142        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5143    )?;
5144    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5145    writeln!(code, "        let r: u32 = rows as u32;")?;
5146    writeln!(code, "        let c: u32 = cols as u32;")?;
5147    writeln!(
5148        code,
5149        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5150    )?;
5151    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5152    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5153    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5154    writeln!(
5155        code,
5156        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5157    )?;
5158    writeln!(
5159        code,
5160        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5161    )?;
5162    writeln!(
5163        code,
5164        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5165    )?;
5166    writeln!(
5167        code,
5168        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
5169    )?;
5170    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5171    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5172    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5173    writeln!(
5174        code,
5175        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5176    )?;
5177    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5178    writeln!(code, "    }}")?;
5179    writeln!(code)?;
5180
5181    // dispatch_matmul_q8_batch
5182    writeln!(
5183        code,
5184        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5185    )?;
5186    writeln!(code, "    ///")?;
5187    writeln!(
5188        code,
5189        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5190    )?;
5191    writeln!(
5192        code,
5193        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5194    )?;
5195    writeln!(
5196        code,
5197        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5198    )?;
5199    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5200    writeln!(code, "        let r: u32 = rows as u32;")?;
5201    writeln!(code, "        let c: u32 = cols as u32;")?;
5202    writeln!(
5203        code,
5204        "        // Tile sizes must match the Metal shader constants."
5205    )?;
5206    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5207    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5208    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5209    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5210    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5211    writeln!(
5212        code,
5213        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5214    )?;
5215    writeln!(
5216        code,
5217        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5218    )?;
5219    writeln!(
5220        code,
5221        "        // dispatch count and reuses each weight load across 32 tokens."
5222    )?;
5223    writeln!(
5224        code,
5225        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5226    )?;
5227    writeln!(
5228        code,
5229        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5230    )?;
5231    writeln!(
5232        code,
5233        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5234    )?;
5235    writeln!(
5236        code,
5237        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5238    )?;
5239    writeln!(
5240        code,
5241        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5242    )?;
5243    writeln!(
5244        code,
5245        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5246    )?;
5247    writeln!(
5248        code,
5249        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5250    )?;
5251    writeln!(
5252        code,
5253        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5254    )?;
5255    writeln!(
5256        code,
5257        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5258    )?;
5259    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5260    writeln!(code, "            let pipe = if use_h4 {{")?;
5261    writeln!(code, "                if num_tokens >= 256 {{")?;
5262    writeln!(
5263        code,
5264        "                    &self.matmul_q8_mma32_hh4_pipeline"
5265    )?;
5266    writeln!(code, "                }} else {{")?;
5267    writeln!(
5268        code,
5269        "                    &self.matmul_q8_mma32_h4_pipeline"
5270    )?;
5271    writeln!(code, "                }}")?;
5272    writeln!(code, "            }} else {{")?;
5273    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5274    writeln!(code, "            }};")?;
5275    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
5276    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5277    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5278    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5279    writeln!(
5280        code,
5281        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5282    )?;
5283    writeln!(
5284        code,
5285        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5286    )?;
5287    writeln!(
5288        code,
5289        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5290    )?;
5291    writeln!(code, "            let row_tgs = rows / MMA32_ROW_TILE;")?;
5292    writeln!(
5293        code,
5294        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5295    )?;
5296    writeln!(
5297        code,
5298        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5299    )?;
5300    writeln!(
5301        code,
5302        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5303    )?;
5304    writeln!(
5305        code,
5306        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5307    )?;
5308    writeln!(
5309        code,
5310        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5311    )?;
5312    writeln!(
5313        code,
5314        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5315    )?;
5316    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5317    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5318    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5319    writeln!(
5320        code,
5321        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5322    )?;
5323    writeln!(
5324        code,
5325        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5326    )?;
5327    writeln!(
5328        code,
5329        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5330    )?;
5331    writeln!(code, "            let row_tgs = rows / MMA_ROW_TILE;")?;
5332    writeln!(
5333        code,
5334        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5335    )?;
5336    writeln!(
5337        code,
5338        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5339    )?;
5340    writeln!(
5341        code,
5342        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5343    )?;
5344    writeln!(
5345        code,
5346        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5347    )?;
5348    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5349    writeln!(
5350        code,
5351        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5352    )?;
5353    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5354    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5355    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5356    writeln!(
5357        code,
5358        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5359    )?;
5360    writeln!(
5361        code,
5362        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5363    )?;
5364    writeln!(
5365        code,
5366        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5367    )?;
5368    writeln!(
5369        code,
5370        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5371    )?;
5372    writeln!(
5373        code,
5374        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5375    )?;
5376    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5377    writeln!(
5378        code,
5379        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5380    )?;
5381    writeln!(
5382        code,
5383        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5384    )?;
5385    writeln!(code, "        }} else {{")?;
5386    writeln!(
5387        code,
5388        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5389    )?;
5390    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5391    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5392    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5393    writeln!(
5394        code,
5395        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5396    )?;
5397    writeln!(
5398        code,
5399        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5400    )?;
5401    writeln!(
5402        code,
5403        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5404    )?;
5405    writeln!(
5406        code,
5407        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5408    )?;
5409    writeln!(
5410        code,
5411        "            let num_tg = (row_tgs * num_tokens) as u64;"
5412    )?;
5413    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5414    writeln!(
5415        code,
5416        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5417    )?;
5418    writeln!(
5419        code,
5420        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5421    )?;
5422    writeln!(code, "        }}")?;
5423    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5424    writeln!(code, "    }}")?;
5425    writeln!(code)?;
5426
5427    // dispatch_matmul_q4_batch
5428    writeln!(
5429        code,
5430        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5431    )?;
5432    writeln!(
5433        code,
5434        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5435    )?;
5436    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5437    writeln!(code, "        let r: u32 = rows as u32;")?;
5438    writeln!(code, "        let c: u32 = cols as u32;")?;
5439    writeln!(
5440        code,
5441        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5442    )?;
5443    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5444    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5445    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5446    writeln!(
5447        code,
5448        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5449    )?;
5450    writeln!(
5451        code,
5452        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5453    )?;
5454    writeln!(
5455        code,
5456        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5457    )?;
5458    writeln!(
5459        code,
5460        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5461    )?;
5462    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5463    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5464    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5465    writeln!(
5466        code,
5467        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5468    )?;
5469    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5470    writeln!(code, "    }}")?;
5471    writeln!(code)?;
5472
5473    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5474    if config.qkv_bias {
5475        writeln!(
5476            code,
5477            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5478        )?;
5479        writeln!(
5480            code,
5481            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5482        )?;
5483        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5484        writeln!(code, "        let r: u32 = rows as u32;")?;
5485        writeln!(
5486            code,
5487            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5488        )?;
5489        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5490        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5491        writeln!(
5492            code,
5493            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5494        )?;
5495        writeln!(
5496            code,
5497            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5498        )?;
5499        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5500        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5501        writeln!(
5502            code,
5503            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5504        )?;
5505        writeln!(
5506            code,
5507            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5508        )?;
5509        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5510        writeln!(code, "    }}")?;
5511        writeln!(code)?;
5512    }
5513
5514    // dispatch_rope_batch
5515    writeln!(
5516        code,
5517        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5518    )?;
5519    writeln!(
5520        code,
5521        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5522    )?;
5523    writeln!(
5524        code,
5525        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5526    )?;
5527    writeln!(
5528        code,
5529        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5530    )?;
5531    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5532    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5533    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5534    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5535    writeln!(
5536        code,
5537        "        let pairs_per_token = num_heads * (head_dim / 2);"
5538    )?;
5539    writeln!(
5540        code,
5541        "        let total_pairs = num_tokens * pairs_per_token;"
5542    )?;
5543    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data.
5544    // Since our batch_qkv_buf is [M, qkv_rows] and Q/K are at offsets within each row,
5545    // we need to pass the buffer at the right byte offset for each token's data.
5546    // Actually, the rope_batch kernel accesses data[token * (num_heads * head_dim) + ...],
5547    // but our layout is data[token * row_stride + data_float_offset + ...].
5548    // We need the kernel to know the row_stride. Let me adjust the kernel approach:
5549    // Since Q and K are contiguous within each token's qkv_rows, and the batch buffer
5550    // is [M, qkv_rows], we can pass the buffer at offset (data_float_offset * 4) and
5551    // use a stride parameter. But the rope_batch kernel as written expects [M, num_heads*head_dim].
5552    //
5553    // Simplest approach: use the single-token rope kernel for each token in a loop.
5554    // This is still efficient because we're dispatching all within the same command encoder.
5555    writeln!(
5556        code,
5557        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5558    )?;
5559    writeln!(code, "        for t in 0..num_tokens {{")?;
5560    writeln!(
5561        code,
5562        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5563    )?;
5564    writeln!(
5565        code,
5566        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5567    )?;
5568    writeln!(
5569        code,
5570        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5571    )?;
5572    writeln!(
5573        code,
5574        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5575    )?;
5576    writeln!(
5577        code,
5578        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5579    )?;
5580    writeln!(
5581        code,
5582        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5583    )?;
5584    writeln!(
5585        code,
5586        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5587    )?;
5588    writeln!(
5589        code,
5590        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5591    )?;
5592    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5593    writeln!(
5594        code,
5595        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5596    )?;
5597    writeln!(
5598        code,
5599        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5600    )?;
5601    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5602    writeln!(code, "        }}")?;
5603    writeln!(code, "    }}")?;
5604    writeln!(code)?;
5605
5606    // dispatch_silu_mul_fused_batch
5607    writeln!(
5608        code,
5609        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5610    )?;
5611    writeln!(
5612        code,
5613        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5614    )?;
5615    writeln!(code, "        let count: u32 = n as u32;")?;
5616    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5617    writeln!(
5618        code,
5619        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5620    )?;
5621    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5622    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5623    writeln!(
5624        code,
5625        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5626    )?;
5627    writeln!(
5628        code,
5629        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5630    )?;
5631    writeln!(code, "        let total = num_tokens * n;")?;
5632    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5633    writeln!(
5634        code,
5635        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5636    )?;
5637    writeln!(
5638        code,
5639        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5640    )?;
5641    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5642    writeln!(code, "    }}")?;
5643    writeln!(code)?;
5644
5645    // dispatch_add_inplace_batch_n (add n elements in-place)
5646    writeln!(
5647        code,
5648        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5649    )?;
5650    writeln!(
5651        code,
5652        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5653    )?;
5654    writeln!(code, "        let count: u32 = total_n as u32;")?;
5655    writeln!(
5656        code,
5657        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5658    )?;
5659    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5660    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5661    writeln!(
5662        code,
5663        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5664    )?;
5665    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5666    writeln!(
5667        code,
5668        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5669    )?;
5670    writeln!(
5671        code,
5672        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5673    )?;
5674    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5675    writeln!(code, "    }}")?;
5676    writeln!(code)?;
5677
5678    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5679    writeln!(
5680        code,
5681        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5682    )?;
5683    writeln!(
5684        code,
5685        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5686    )?;
5687    writeln!(code, "        let n: u32 = count as u32;")?;
5688    writeln!(
5689        code,
5690        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5691    )?;
5692    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5693    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5694    writeln!(
5695        code,
5696        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5697    )?;
5698    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5699    writeln!(
5700        code,
5701        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5702    )?;
5703    writeln!(
5704        code,
5705        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5706    )?;
5707    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5708    writeln!(code, "    }}")?;
5709    writeln!(code)?;
5710
5711    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
5712    writeln!(
5713        code,
5714        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5715    )?;
5716    writeln!(
5717        code,
5718        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5719    )?;
5720    writeln!(code, "        let n: u32 = count as u32;")?;
5721    writeln!(
5722        code,
5723        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5724    )?;
5725    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5726    writeln!(
5727        code,
5728        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5729    )?;
5730    writeln!(
5731        code,
5732        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5733    )?;
5734    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5735    writeln!(
5736        code,
5737        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5738    )?;
5739    writeln!(
5740        code,
5741        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5742    )?;
5743    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5744    writeln!(code, "    }}")?;
5745    writeln!(code)?;
5746
5747    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
5748    writeln!(
5749        code,
5750        "    /// Copy from src at byte offset to dst at float offset."
5751    )?;
5752    writeln!(
5753        code,
5754        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5755    )?;
5756    writeln!(code, "        let n: u32 = count as u32;")?;
5757    writeln!(
5758        code,
5759        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5760    )?;
5761    writeln!(
5762        code,
5763        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5764    )?;
5765    writeln!(
5766        code,
5767        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5768    )?;
5769    writeln!(
5770        code,
5771        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5772    )?;
5773    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5774    writeln!(
5775        code,
5776        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5777    )?;
5778    writeln!(
5779        code,
5780        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5781    )?;
5782    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5783    writeln!(code, "    }}")?;
5784    writeln!(code)?;
5785
5786    // dispatch_copy_kv_batch
5787    writeln!(
5788        code,
5789        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
5790    )?;
5791    writeln!(
5792        code,
5793        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
5794    )?;
5795    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5796    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5797    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5798    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5799    writeln!(code, "        let so: u32 = src_offset as u32;")?;
5800    writeln!(
5801        code,
5802        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
5803    )?;
5804    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5805    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5806    writeln!(
5807        code,
5808        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5809    )?;
5810    writeln!(
5811        code,
5812        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5813    )?;
5814    writeln!(
5815        code,
5816        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5817    )?;
5818    writeln!(
5819        code,
5820        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5821    )?;
5822    writeln!(
5823        code,
5824        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
5825    )?;
5826    writeln!(code, "        let total = num_tokens * kv_dim;")?;
5827    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5828    writeln!(
5829        code,
5830        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5831    )?;
5832    writeln!(
5833        code,
5834        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5835    )?;
5836    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5837    writeln!(code, "    }}")?;
5838    writeln!(code)?;
5839
5840    // dispatch_attention_batch
5841    writeln!(
5842        code,
5843        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
5844    )?;
5845    writeln!(
5846        code,
5847        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
5848    )?;
5849    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5850    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5851    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5852    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5853    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5854    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
5855    // The legacy attention_batch kernel materializes a scores[2048] array
5856    // in threadgroup memory — fast at short seq_len but silently corrupts
5857    // when `base_pos + num_tokens > 2048`.  Switch to the flash kernel
5858    // (tiled online softmax, O(head_dim) threadgroup memory, no seq_len
5859    // cap) above that boundary.  Below it the legacy kernel is ~15% faster
5860    // due to the flash kernel's per-tile barrier overhead at low seq_len.
5861    writeln!(code, "        let max_seq = base_pos + num_tokens;")?;
5862    // Re-testing after the chunking + scores[4096] fixes showed the flash
5863    // kernel is numerically correct — the earlier "garbled output" was from
5864    // the scores[2048] overflow, not from flash itself.  Benchmarks show
5865    // flash is 7–14% slower than legacy because its online-softmax path has
5866    // more per-tile barriers and no MMA acceleration (yet).  Default to the
5867    // legacy kernel; keep flash wired up for future use (larger max_seq_len
5868    // caps, or as the base for an MMA-accelerated flash variant).
5869    writeln!(code, "        let _ = max_seq;")?;
5870    writeln!(code, "        let pipe = &self.attention_batch_pipeline;")?;
5871    writeln!(code, "        enc.set_compute_pipeline_state(pipe);")?;
5872    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
5873    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5874    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5875    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5876    writeln!(
5877        code,
5878        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5879    )?;
5880    writeln!(
5881        code,
5882        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5883    )?;
5884    writeln!(
5885        code,
5886        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5887    )?;
5888    writeln!(
5889        code,
5890        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5891    )?;
5892    writeln!(
5893        code,
5894        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5895    )?;
5896    writeln!(
5897        code,
5898        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5899    )?;
5900    writeln!(
5901        code,
5902        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
5903    )?;
5904    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5905    writeln!(
5906        code,
5907        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
5908    )?;
5909    writeln!(
5910        code,
5911        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5912    )?;
5913    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5914    writeln!(code, "    }}")?;
5915    writeln!(code)?;
5916
5917    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
5918    writeln!(
5919        code,
5920        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
5921    )?;
5922    writeln!(
5923        code,
5924        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
5925    )?;
5926    writeln!(
5927        code,
5928        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
5929    )?;
5930    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5931    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5932    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
5933    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5934    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5935    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
5936    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5937    writeln!(
5938        code,
5939        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
5940    )?;
5941    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
5942    writeln!(
5943        code,
5944        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5945    )?;
5946    writeln!(
5947        code,
5948        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5949    )?;
5950    writeln!(
5951        code,
5952        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
5953    )?;
5954    writeln!(
5955        code,
5956        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5957    )?;
5958    writeln!(
5959        code,
5960        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5961    )?;
5962    writeln!(
5963        code,
5964        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5965    )?;
5966    writeln!(
5967        code,
5968        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5969    )?;
5970    writeln!(
5971        code,
5972        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
5973    )?;
5974    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5975    writeln!(
5976        code,
5977        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5978    )?;
5979    writeln!(
5980        code,
5981        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5982    )?;
5983    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5984    writeln!(code, "    }}")?;
5985    writeln!(code)?;
5986
5987    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
5988    writeln!(
5989        code,
5990        "    /// Dispatch fused K+V cache copy in one kernel launch."
5991    )?;
5992    writeln!(
5993        code,
5994        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
5995    )?;
5996    writeln!(
5997        code,
5998        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
5999    )?;
6000    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6001    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6002    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6003    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6004    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6005    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6006    writeln!(
6007        code,
6008        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6009    )?;
6010    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6011    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6012    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6013    writeln!(
6014        code,
6015        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6016    )?;
6017    writeln!(
6018        code,
6019        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6020    )?;
6021    writeln!(
6022        code,
6023        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6024    )?;
6025    writeln!(
6026        code,
6027        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6028    )?;
6029    writeln!(
6030        code,
6031        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6032    )?;
6033    writeln!(
6034        code,
6035        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6036    )?;
6037    writeln!(
6038        code,
6039        "        let total = num_tokens * kv_dim * 2;  // K + V"
6040    )?;
6041    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6042    writeln!(
6043        code,
6044        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6045    )?;
6046    writeln!(
6047        code,
6048        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6049    )?;
6050    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6051    writeln!(code, "    }}")?;
6052
6053    writeln!(code, "}}")?;
6054    writeln!(code)?;
6055
6056    Ok(())
6057}
6058
6059fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6060    writeln!(
6061        code,
6062        "// ── Helper functions ──────────────────────────────────"
6063    )?;
6064    writeln!(code)?;
6065    writeln!(
6066        code,
6067        "/// Create a compute pipeline from a named function in the Metal library."
6068    )?;
6069    writeln!(
6070        code,
6071        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6072    )?;
6073    writeln!(
6074        code,
6075        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6076    )?;
6077    writeln!(
6078        code,
6079        "    device.new_compute_pipeline_state_with_function(&func)"
6080    )?;
6081    writeln!(
6082        code,
6083        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6084    )?;
6085    writeln!(code, "}}")?;
6086    writeln!(code)?;
6087
6088    Ok(())
6089}
6090
6091// ---------------------------------------------------------------------------
6092// main.rs generation
6093// ---------------------------------------------------------------------------
6094
6095fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6096    let _sanitized = sanitize_name(model_name);
6097    let _vocab = config.vocab_size;
6098
6099    let mut code = String::with_capacity(16 * 1024);
6100    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6101    writeln!(
6102        code,
6103        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6104    )?;
6105    writeln!(code)?;
6106    writeln!(code, "mod model;")?;
6107    writeln!(code)?;
6108    writeln!(code, "use std::io::Write;")?;
6109    writeln!(code, "use std::time::Instant;")?;
6110    writeln!(code, "use serde::Deserialize;")?;
6111    writeln!(code)?;
6112
6113    // -- main function --
6114    writeln!(code, "fn main() {{")?;
6115    writeln!(
6116        code,
6117        "    let args: Vec<String> = std::env::args().collect();"
6118    )?;
6119    writeln!(code)?;
6120    writeln!(
6121        code,
6122        "    // Detect --serve mode (only requires weights + tokenizer)"
6123    )?;
6124    writeln!(
6125        code,
6126        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6127    )?;
6128    writeln!(code)?;
6129    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6130    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6131    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6132    writeln!(code, "        std::process::exit(1);")?;
6133    writeln!(code, "    }}")?;
6134    writeln!(code)?;
6135    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6136    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6137    writeln!(code, "        std::process::exit(1);")?;
6138    writeln!(code, "    }}")?;
6139    writeln!(code)?;
6140    writeln!(code, "    let weights_path = &args[1];")?;
6141    writeln!(code, "    let tokenizer_path = &args[2];")?;
6142    writeln!(code)?;
6143    writeln!(code, "    // Parse optional flags")?;
6144    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6145    writeln!(code, "    let mut port: u16 = 8080;")?;
6146    writeln!(
6147        code,
6148        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6149    )?;
6150    writeln!(
6151        code,
6152        "    let profile = args.iter().any(|a| a == \"--profile\");"
6153    )?;
6154    writeln!(code, "    let mut i = 3;")?;
6155    writeln!(code, "    while i < args.len() {{")?;
6156    writeln!(
6157        code,
6158        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6159    )?;
6160    writeln!(
6161        code,
6162        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6163    )?;
6164    writeln!(code, "            i += 2;")?;
6165    writeln!(
6166        code,
6167        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6168    )?;
6169    writeln!(
6170        code,
6171        "            port = args[i + 1].parse().unwrap_or(8080);"
6172    )?;
6173    writeln!(code, "            i += 2;")?;
6174    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6175    writeln!(code, "            i += 1;")?;
6176    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6177    writeln!(code, "            i += 1;")?;
6178    writeln!(code, "        }} else {{")?;
6179    writeln!(code, "            i += 1;")?;
6180    writeln!(code, "        }}")?;
6181    writeln!(code, "    }}")?;
6182    writeln!(code)?;
6183
6184    // -- load model (shared by both modes) --
6185    writeln!(
6186        code,
6187        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6188    )?;
6189    writeln!(
6190        code,
6191        "    let weights_file = std::fs::File::open(weights_path)"
6192    )?;
6193    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6194    writeln!(
6195        code,
6196        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6197    )?;
6198    writeln!(code)?;
6199    writeln!(code, "    // Load tokenizer")?;
6200    writeln!(
6201        code,
6202        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6203    )?;
6204    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6205    writeln!(code)?;
6206    writeln!(code, "    // Create Metal model")?;
6207    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6208    writeln!(
6209        code,
6210        "    let mut model = model::MetalModel::new(&weights_mmap);"
6211    )?;
6212    writeln!(code)?;
6213
6214    // -- branch: serve vs CLI --
6215    writeln!(code, "    if serve_mode {{")?;
6216    writeln!(code, "        serve(model, tokenizer, port);")?;
6217    writeln!(code, "    }} else {{")?;
6218    writeln!(code, "        let prompt = &args[3];")?;
6219    writeln!(
6220        code,
6221        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6222    )?;
6223    writeln!(code, "    }}")?;
6224    writeln!(code, "}}")?;
6225    writeln!(code)?;
6226
6227    // -- cli_mode function --
6228    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6229    writeln!(code, "    // Tokenize prompt")?;
6230    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6231    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6232    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6233    writeln!(code)?;
6234    writeln!(
6235        code,
6236        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6237    )?;
6238    writeln!(
6239        code,
6240        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6241    )?;
6242    writeln!(
6243        code,
6244        "    // The last token uses synchronous forward() to get logits."
6245    )?;
6246    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6247    writeln!(code, "    let prefill_start = Instant::now();")?;
6248    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6249    writeln!(
6250        code,
6251        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6252    )?;
6253    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6254    writeln!(code, "    }} else {{")?;
6255    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6256    writeln!(code, "    }};")?;
6257    writeln!(
6258        code,
6259        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6260    )?;
6261    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6262    writeln!(
6263        code,
6264        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6265    )?;
6266    writeln!(
6267        code,
6268        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6269    )?;
6270    writeln!(code)?;
6271    writeln!(code, "    // Generate tokens")?;
6272    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6273    writeln!(code, "    let gen_start = Instant::now();")?;
6274    writeln!(code, "    let mut generated_count: usize = 0;")?;
6275    writeln!(code)?;
6276    writeln!(code, "    for _ in 0..max_tokens {{")?;
6277    writeln!(
6278        code,
6279        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6280    )?;
6281    writeln!(code, "            if !quiet {{")?;
6282    writeln!(code, "                print!(\"{{}}\", text);")?;
6283    writeln!(code, "                std::io::stdout().flush().ok();")?;
6284    writeln!(code, "            }}")?;
6285    writeln!(code, "        }}")?;
6286    writeln!(code, "        generated_count += 1;")?;
6287    writeln!(code)?;
6288    writeln!(
6289        code,
6290        "        // Use profiling forward for first token when --profile is set"
6291    )?;
6292    writeln!(
6293        code,
6294        "        let logits = if profile && generated_count == 1 {{"
6295    )?;
6296    writeln!(code, "            model.forward_profile(next_token)")?;
6297    writeln!(code, "        }} else {{")?;
6298    writeln!(code, "            model.forward(next_token)")?;
6299    writeln!(code, "        }};")?;
6300    writeln!(code, "        next_token = argmax(&logits);")?;
6301    writeln!(code)?;
6302    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6303    writeln!(code, "        if next_token == 2 {{")?;
6304    writeln!(code, "            break;")?;
6305    writeln!(code, "        }}")?;
6306    writeln!(code)?;
6307    writeln!(
6308        code,
6309        "        // Yield between tokens to reduce sustained GPU thermal load."
6310    )?;
6311    writeln!(
6312        code,
6313        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6314    )?;
6315    writeln!(
6316        code,
6317        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6318    )?;
6319    writeln!(
6320        code,
6321        "        // briefly, providing a micro-break that helps sustain peak throughput."
6322    )?;
6323    writeln!(code, "        std::thread::yield_now();")?;
6324    writeln!(code, "    }}")?;
6325    writeln!(code, "    if !quiet {{")?;
6326    writeln!(code, "        println!();")?;
6327    writeln!(code, "    }}")?;
6328    writeln!(
6329        code,
6330        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6331    )?;
6332    writeln!(
6333        code,
6334        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6335    )?;
6336    writeln!(
6337        code,
6338        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6339    )?;
6340    writeln!(code, "}}")?;
6341    writeln!(code)?;
6342
6343    // -- argmax helper --
6344    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6345    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6346    writeln!(code, "    logits.iter()")?;
6347    writeln!(code, "        .enumerate()")?;
6348    writeln!(
6349        code,
6350        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6351    )?;
6352    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6353    writeln!(code, "        .unwrap_or(0)")?;
6354    writeln!(code, "}}")?;
6355    writeln!(code)?;
6356
6357    // -- Request/Response types for OpenAI API --
6358    writeln!(
6359        code,
6360        "// -----------------------------------------------------------------------"
6361    )?;
6362    writeln!(code, "// OpenAI-compatible API server")?;
6363    writeln!(
6364        code,
6365        "// -----------------------------------------------------------------------"
6366    )?;
6367    writeln!(code)?;
6368    writeln!(code, "#[derive(Deserialize)]")?;
6369    writeln!(code, "struct ChatRequest {{")?;
6370    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6371    writeln!(code, "    #[serde(default)]")?;
6372    writeln!(code, "    stream: Option<bool>,")?;
6373    writeln!(code, "    #[serde(default)]")?;
6374    writeln!(code, "    max_tokens: Option<usize>,")?;
6375    writeln!(code, "    #[serde(default)]")?;
6376    writeln!(code, "    temperature: Option<f32>,")?;
6377    writeln!(code, "    #[serde(default)]")?;
6378    writeln!(code, "    model: Option<String>,")?;
6379    writeln!(code, "}}")?;
6380    writeln!(code)?;
6381    writeln!(code, "#[derive(Deserialize)]")?;
6382    writeln!(code, "struct ChatMessage {{")?;
6383    writeln!(code, "    role: String,")?;
6384    writeln!(code, "    content: String,")?;
6385    writeln!(code, "}}")?;
6386    writeln!(code)?;
6387
6388    // -- format_chat_messages --
6389    writeln!(
6390        code,
6391        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6392    )?;
6393    writeln!(code, "    let mut prompt = String::new();")?;
6394    writeln!(code, "    for msg in messages {{")?;
6395    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6396    writeln!(code, "    }}")?;
6397    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6398    writeln!(code, "    prompt")?;
6399    writeln!(code, "}}")?;
6400    writeln!(code)?;
6401
6402    // -- prefill helper --
6403    writeln!(
6404        code,
6405        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6406    )?;
6407    writeln!(code, "    let len = tokens.len();")?;
6408    writeln!(code, "    if len > 1 {{")?;
6409    writeln!(
6410        code,
6411        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6412    )?;
6413    writeln!(code, "    }}")?;
6414    writeln!(code, "    model.forward(tokens[len - 1])")?;
6415    writeln!(code, "}}")?;
6416    writeln!(code)?;
6417
6418    // -- serve function --
6419    writeln!(
6420        code,
6421        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6422    )?;
6423    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6424    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6425    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6426    writeln!(
6427        code,
6428        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6429    )?;
6430    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6431    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6432    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6433    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6434    writeln!(code)?;
6435    writeln!(code, "    for request in server.incoming_requests() {{")?;
6436    writeln!(code, "        let method = request.method().to_string();")?;
6437    writeln!(code, "        let url = request.url().to_string();")?;
6438    writeln!(code)?;
6439    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6440
6441    // -- POST /v1/chat/completions --
6442    writeln!(
6443        code,
6444        "            (\"POST\", \"/v1/chat/completions\") => {{"
6445    )?;
6446    writeln!(
6447        code,
6448        "                handle_chat_completion(&mut model, &tokenizer, request);"
6449    )?;
6450    writeln!(code, "            }}")?;
6451
6452    // -- GET /v1/models --
6453    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6454    writeln!(code, "                let body = serde_json::json!({{")?;
6455    writeln!(code, "                    \"object\": \"list\",")?;
6456    writeln!(code, "                    \"data\": [{{")?;
6457    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6458    writeln!(code, "                        \"object\": \"model\",")?;
6459    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6460    writeln!(code, "                    }}]")?;
6461    writeln!(code, "                }});")?;
6462    writeln!(
6463        code,
6464        "                let resp = tiny_http::Response::from_string(body.to_string())"
6465    )?;
6466    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6467    writeln!(code, "                request.respond(resp).ok();")?;
6468    writeln!(code, "            }}")?;
6469
6470    // -- GET /health --
6471    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6472    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6473    writeln!(code, "                request.respond(resp).ok();")?;
6474    writeln!(code, "            }}")?;
6475
6476    // -- 404 --
6477    writeln!(code, "            _ => {{")?;
6478    writeln!(
6479        code,
6480        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6481    )?;
6482    writeln!(code, "                    .with_status_code(404);")?;
6483    writeln!(code, "                request.respond(resp).ok();")?;
6484    writeln!(code, "            }}")?;
6485    writeln!(code, "        }}")?;
6486    writeln!(code, "    }}")?;
6487    writeln!(code, "}}")?;
6488    writeln!(code)?;
6489
6490    // -- handle_chat_completion --
6491    writeln!(code, "fn handle_chat_completion(")?;
6492    writeln!(code, "    model: &mut model::MetalModel,")?;
6493    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6494    writeln!(code, "    mut request: tiny_http::Request,")?;
6495    writeln!(code, ") {{")?;
6496    writeln!(code, "    // Read request body")?;
6497    writeln!(code, "    let mut body = String::new();")?;
6498    writeln!(
6499        code,
6500        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6501    )?;
6502    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6503    writeln!(code, "            .with_status_code(400);")?;
6504    writeln!(code, "        request.respond(resp).ok();")?;
6505    writeln!(code, "        return;")?;
6506    writeln!(code, "    }}")?;
6507    writeln!(code)?;
6508    writeln!(code, "    // Parse JSON")?;
6509    writeln!(
6510        code,
6511        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6512    )?;
6513    writeln!(code, "        Ok(r) => r,")?;
6514    writeln!(code, "        Err(e) => {{")?;
6515    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6516    writeln!(code, "                .with_status_code(400);")?;
6517    writeln!(code, "            request.respond(resp).ok();")?;
6518    writeln!(code, "            return;")?;
6519    writeln!(code, "        }}")?;
6520    writeln!(code, "    }};")?;
6521    writeln!(code)?;
6522    writeln!(
6523        code,
6524        "    let prompt = format_chat_messages(&req.messages);"
6525    )?;
6526    writeln!(
6527        code,
6528        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6529    )?;
6530    writeln!(code, "        Ok(e) => e,")?;
6531    writeln!(code, "        Err(e) => {{")?;
6532    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6533    writeln!(code, "                .with_status_code(500);")?;
6534    writeln!(code, "            request.respond(resp).ok();")?;
6535    writeln!(code, "            return;")?;
6536    writeln!(code, "        }}")?;
6537    writeln!(code, "    }};")?;
6538    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6539    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6540    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6541    writeln!(
6542        code,
6543        "    let _temperature = req.temperature.unwrap_or(1.0);"
6544    )?;
6545    writeln!(code)?;
6546
6547    // -- Reset KV cache for each request --
6548    writeln!(code, "    model.reset();")?;
6549    writeln!(code)?;
6550
6551    // -- Prefill with timing --
6552    writeln!(code, "    let prefill_start = Instant::now();")?;
6553    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6554    writeln!(
6555        code,
6556        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6557    )?;
6558    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6559    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6560    writeln!(code)?;
6561
6562    writeln!(code, "    if stream {{")?;
6563
6564    // -- SSE streaming response --
6565    writeln!(
6566        code,
6567        "        // SSE streaming: generate tokens and build SSE body"
6568    )?;
6569    writeln!(code, "        let gen_start = Instant::now();")?;
6570    writeln!(code, "        let mut generated_count: usize = 0;")?;
6571    writeln!(code, "        let mut sse_body = String::new();")?;
6572    writeln!(code, "        for _ in 0..max_tokens {{")?;
6573    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6574    writeln!(
6575        code,
6576        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6577    )?;
6578    writeln!(
6579        code,
6580        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6581    )?;
6582    writeln!(
6583        code,
6584        "                // escaped includes surrounding quotes, strip them"
6585    )?;
6586    writeln!(
6587        code,
6588        "                let inner = &escaped[1..escaped.len()-1];"
6589    )?;
6590    writeln!(code, "                sse_body.push_str(&format!(")?;
6591    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6592    writeln!(code, "                    inner")?;
6593    writeln!(code, "                ));")?;
6594    writeln!(code, "            }}")?;
6595    writeln!(code, "            generated_count += 1;")?;
6596    writeln!(code, "            let logits = model.forward(next_token);")?;
6597    writeln!(code, "            next_token = argmax(&logits);")?;
6598    writeln!(code, "        }}")?;
6599    writeln!(
6600        code,
6601        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6602    )?;
6603    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6604    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6605    writeln!(code)?;
6606    writeln!(
6607        code,
6608        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6609    )?;
6610    writeln!(code, "        sse_body.push_str(&format!(")?;
6611    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6612    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6613    writeln!(code, "        ));")?;
6614    writeln!(code)?;
6615    writeln!(
6616        code,
6617        "        let resp = tiny_http::Response::from_string(sse_body)"
6618    )?;
6619    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6620    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6621    writeln!(code, "        request.respond(resp).ok();")?;
6622
6623    writeln!(code, "    }} else {{")?;
6624
6625    // -- Non-streaming response --
6626    writeln!(
6627        code,
6628        "        // Non-streaming: generate all tokens, return JSON"
6629    )?;
6630    writeln!(code, "        let gen_start = Instant::now();")?;
6631    writeln!(code, "        let mut generated_count: usize = 0;")?;
6632    writeln!(code, "        let mut generated = String::new();")?;
6633    writeln!(code, "        for _ in 0..max_tokens {{")?;
6634    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6635    writeln!(
6636        code,
6637        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6638    )?;
6639    writeln!(code, "                generated.push_str(&text);")?;
6640    writeln!(code, "            }}")?;
6641    writeln!(code, "            generated_count += 1;")?;
6642    writeln!(code, "            let logits = model.forward(next_token);")?;
6643    writeln!(code, "            next_token = argmax(&logits);")?;
6644    writeln!(code, "        }}")?;
6645    writeln!(
6646        code,
6647        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6648    )?;
6649    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6650    writeln!(code)?;
6651    writeln!(code, "        let resp_json = serde_json::json!({{")?;
6652    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
6653    writeln!(code, "            \"object\": \"chat.completion\",")?;
6654    writeln!(code, "            \"choices\": [{{")?;
6655    writeln!(code, "                \"index\": 0,")?;
6656    writeln!(code, "                \"message\": {{")?;
6657    writeln!(code, "                    \"role\": \"assistant\",")?;
6658    writeln!(code, "                    \"content\": generated")?;
6659    writeln!(code, "                }},")?;
6660    writeln!(code, "                \"finish_reason\": \"stop\"")?;
6661    writeln!(code, "            }}],")?;
6662    writeln!(code, "            \"usage\": {{")?;
6663    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
6664    writeln!(
6665        code,
6666        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6667    )?;
6668    writeln!(
6669        code,
6670        "                \"generation_tokens\": generated_count,"
6671    )?;
6672    writeln!(
6673        code,
6674        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6675    )?;
6676    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
6677    writeln!(code, "            }}")?;
6678    writeln!(code, "        }});")?;
6679    writeln!(
6680        code,
6681        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
6682    )?;
6683    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6684    writeln!(code, "        request.respond(resp).ok();")?;
6685    writeln!(code, "    }}")?;
6686    writeln!(code, "}}")?;
6687
6688    Ok(code)
6689}
6690
6691// ---------------------------------------------------------------------------
6692// Tests
6693// ---------------------------------------------------------------------------
6694
6695#[cfg(test)]
6696mod tests {
6697    use super::*;
6698    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
6699
6700    fn minimal_config() -> ModelConfig {
6701        ModelConfig {
6702            architecture: Architecture::Llama,
6703            hidden_size: 64,
6704            intermediate_size: 128,
6705            num_layers: 2,
6706            num_attention_heads: 4,
6707            num_kv_heads: 4,
6708            head_dim: 16,
6709            vocab_size: 256,
6710            max_seq_len: 512,
6711            rms_norm_eps: 1e-5,
6712            rope_theta: 10000.0,
6713            dtype: DType::F32,
6714            sliding_window_size: None,
6715            qkv_bias: false,
6716        }
6717    }
6718
6719    fn minimal_graph() -> Graph {
6720        Graph::new("test-metal").with_config(minimal_config())
6721    }
6722
6723    #[test]
6724    fn generate_metal_project_creates_files() {
6725        let dir = tempfile::tempdir().unwrap();
6726        let graph = minimal_graph();
6727        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
6728
6729        assert!(
6730            dir.path().join("Cargo.toml").exists(),
6731            "Cargo.toml should be created"
6732        );
6733        assert!(
6734            dir.path().join("src/model.rs").exists(),
6735            "src/model.rs should be created"
6736        );
6737        assert!(
6738            dir.path().join("src/main.rs").exists(),
6739            "src/main.rs should be created"
6740        );
6741        assert!(
6742            dir.path().join("shaders/kernels.metal").exists(),
6743            "shaders/kernels.metal should be created"
6744        );
6745    }
6746
6747    #[test]
6748    fn generated_cargo_toml_has_metal_dep() {
6749        let toml = generate_cargo_toml("my-model");
6750        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
6751        assert!(
6752            toml.contains("tokenizers"),
6753            "Cargo.toml should depend on tokenizers"
6754        );
6755        assert!(
6756            toml.contains("memmap2"),
6757            "Cargo.toml should depend on memmap2"
6758        );
6759        assert!(toml.contains("half"), "Cargo.toml should depend on half");
6760    }
6761
6762    #[test]
6763    fn generated_model_rs_contains_metal_code() {
6764        let config = minimal_config();
6765        let model_rs = generate_model_rs(&config).unwrap();
6766
6767        assert!(
6768            model_rs.contains("pub struct MetalModel"),
6769            "model.rs should define MetalModel struct"
6770        );
6771        assert!(
6772            model_rs.contains("matmul_pipeline: ComputePipelineState"),
6773            "MetalModel should have matmul_pipeline field"
6774        );
6775        assert!(
6776            model_rs.contains("Device::system_default()"),
6777            "model.rs should use Metal device"
6778        );
6779        assert!(
6780            model_rs.contains("new_library_with_source"),
6781            "model.rs should compile Metal shaders"
6782        );
6783        assert!(
6784            model_rs.contains("fn new(weights: &[u8])"),
6785            "MetalModel should implement new()"
6786        );
6787        assert!(
6788            model_rs.contains("fn forward(&mut self, token_id: u32)"),
6789            "MetalModel should implement forward()"
6790        );
6791    }
6792
6793    #[test]
6794    fn generated_shaders_contain_kernel_names() {
6795        let shaders = generate_metal_shaders(&minimal_config());
6796
6797        assert!(
6798            shaders.contains("kernel void matmul_vec"),
6799            "shaders should contain matmul_vec kernel"
6800        );
6801        assert!(
6802            shaders.contains("kernel void rms_norm"),
6803            "shaders should contain rms_norm kernel"
6804        );
6805        assert!(
6806            shaders.contains("kernel void rope"),
6807            "shaders should contain rope kernel"
6808        );
6809        assert!(
6810            shaders.contains("kernel void softmax"),
6811            "shaders should contain softmax kernel"
6812        );
6813        assert!(
6814            shaders.contains("kernel void silu_mul("),
6815            "shaders should contain silu_mul kernel"
6816        );
6817        assert!(
6818            shaders.contains("kernel void silu_mul_fused"),
6819            "shaders should contain silu_mul_fused kernel"
6820        );
6821        assert!(
6822            shaders.contains("kernel void elementwise_add"),
6823            "shaders should contain elementwise_add kernel"
6824        );
6825        assert!(
6826            shaders.contains("kernel void attention"),
6827            "shaders should contain attention kernel"
6828        );
6829        assert!(
6830            shaders.contains("kernel void add_inplace"),
6831            "shaders should contain add_inplace kernel"
6832        );
6833        assert!(
6834            shaders.contains("kernel void copy_buffer"),
6835            "shaders should contain copy_buffer kernel"
6836        );
6837        assert!(
6838            shaders.contains("kernel void copy_offset"),
6839            "shaders should contain copy_offset kernel"
6840        );
6841    }
6842
6843    #[test]
6844    fn generated_shaders_use_simdgroup_features() {
6845        let shaders = generate_metal_shaders(&minimal_config());
6846
6847        assert!(
6848            shaders.contains("threadgroup_barrier"),
6849            "shaders should use threadgroup barriers"
6850        );
6851        assert!(
6852            shaders.contains("threadgroup float"),
6853            "shaders should use threadgroup shared memory"
6854        );
6855        assert!(
6856            shaders.contains("thread_index_in_threadgroup"),
6857            "shaders should use threadgroup indexing"
6858        );
6859        assert!(
6860            shaders.contains("simd_sum"),
6861            "shaders should use simd_sum for warp-level reduction"
6862        );
6863        assert!(
6864            shaders.contains("simd_max"),
6865            "attention kernel should use simd_max for cooperative softmax"
6866        );
6867        assert!(
6868            shaders.contains("thread_index_in_simdgroup"),
6869            "shaders should use simdgroup lane indexing"
6870        );
6871        assert!(
6872            shaders.contains("simdgroup_index_in_threadgroup"),
6873            "shaders should use simdgroup indexing within threadgroup"
6874        );
6875        assert!(
6876            shaders.contains("float4"),
6877            "matmul_vec should use float4 vectorized loads"
6878        );
6879    }
6880
6881    #[test]
6882    fn generated_main_rs_has_tokenizer_usage() {
6883        let config = minimal_config();
6884        let main_rs = generate_main_rs("test-model", &config).unwrap();
6885
6886        assert!(
6887            main_rs.contains("tokenizers::Tokenizer"),
6888            "main.rs should use tokenizers crate"
6889        );
6890        assert!(
6891            main_rs.contains("MetalModel::new"),
6892            "main.rs should call MetalModel::new"
6893        );
6894        assert!(
6895            main_rs.contains("model.forward"),
6896            "main.rs should call model.forward"
6897        );
6898        assert!(
6899            main_rs.contains("memmap2"),
6900            "main.rs should use memmap2 for zero-copy weight loading"
6901        );
6902    }
6903
6904    #[test]
6905    fn missing_config_returns_error() {
6906        let dir = tempfile::tempdir().unwrap();
6907        let graph = Graph::new("no-config");
6908        let result = generate_metal_project(&graph, dir.path(), "fail");
6909        assert!(
6910            matches!(result, Err(MetalCodegenError::MissingConfig)),
6911            "should fail with MissingConfig when graph has no config"
6912        );
6913    }
6914
6915    #[test]
6916    fn sanitize_name_works() {
6917        assert_eq!(sanitize_name("My Model!"), "my-model");
6918        assert_eq!(sanitize_name("test_model"), "test-model");
6919        assert_eq!(sanitize_name("simple"), "simple");
6920    }
6921
6922    #[test]
6923    fn generated_forward_uses_single_command_buffer() {
6924        let config = minimal_config();
6925        let model_rs = generate_model_rs(&config).unwrap();
6926
6927        // The forward function should create exactly one command buffer.
6928        // Use the exact signature to avoid matching forward_prefill/forward_profile.
6929        let forward_start = model_rs
6930            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6931            .unwrap();
6932        let forward_body = &model_rs[forward_start..];
6933        // End at the next pub/private method
6934        let forward_end = forward_body
6935            .find("\n    pub fn forward_profile")
6936            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6937            .or_else(|| forward_body.find("\n    fn dispatch_"))
6938            .unwrap_or(forward_body.len());
6939        let forward_code = &forward_body[..forward_end];
6940
6941        // Should have exactly one new_command_buffer call
6942        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
6943        assert_eq!(
6944            cmd_buf_count, 1,
6945            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
6946        );
6947
6948        // Should have exactly one commit call
6949        let commit_count = forward_code.matches("cmd.commit()").count();
6950        assert_eq!(
6951            commit_count, 1,
6952            "forward() should commit exactly once, found {commit_count}"
6953        );
6954
6955        // Should wait: once for cmd + possibly once for prev_cmd drain
6956        let wait_count = forward_code.matches("wait_until_completed()").count();
6957        assert!(
6958            wait_count >= 1 && wait_count <= 2,
6959            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
6960        );
6961    }
6962
6963    #[test]
6964    fn generated_model_has_preallocated_working_buffers() {
6965        let config = minimal_config();
6966        let model_rs = generate_model_rs(&config).unwrap();
6967
6968        for buf_name in &[
6969            "normed_buf",
6970            "qkv_buf",
6971            "attn_out_buf",
6972            "attn_proj_buf",
6973            "gate_up_buf",
6974            "ffn_hidden_buf",
6975            "ffn_out_buf",
6976            "add_tmp_buf",
6977        ] {
6978            assert!(
6979                model_rs.contains(&format!("{buf_name}: Buffer")),
6980                "MetalModel should have pre-allocated {buf_name} field"
6981            );
6982        }
6983    }
6984
6985    #[test]
6986    fn generated_dispatch_helpers_take_compute_encoder_ref() {
6987        let config = minimal_config();
6988        let model_rs = generate_model_rs(&config).unwrap();
6989
6990        for method in &[
6991            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
6992            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
6993            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
6994            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
6995            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
6996            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
6997            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
6998            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
6999            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7000            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7001            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7002            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7003            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7004        ] {
7005            assert!(
7006                model_rs.contains(method),
7007                "model.rs should contain dispatch helper: {method}"
7008            );
7009        }
7010    }
7011
7012    #[test]
7013    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7014        let config = minimal_config();
7015        let model_rs = generate_model_rs(&config).unwrap();
7016
7017        // Find dispatch helpers section and check none create their own encoders
7018        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7019        let helpers_code = &model_rs[helpers_start..];
7020
7021        // None of the dispatch_ helpers should call new_command_buffer
7022        assert!(
7023            !helpers_code.contains("self.queue.new_command_buffer()"),
7024            "dispatch helpers should not create their own command buffers"
7025        );
7026
7027        // None should create their own compute encoders
7028        assert!(
7029            !helpers_code.contains("new_compute_command_encoder()"),
7030            "dispatch helpers should not create their own compute encoders"
7031        );
7032
7033        // None should call end_encoding
7034        assert!(
7035            !helpers_code.contains("end_encoding()"),
7036            "dispatch helpers should not call end_encoding"
7037        );
7038
7039        // None should call commit or wait
7040        assert!(
7041            !helpers_code.contains(".commit()"),
7042            "dispatch helpers should not commit command buffers"
7043        );
7044        assert!(
7045            !helpers_code.contains("wait_until_completed"),
7046            "dispatch helpers should not wait on command buffers"
7047        );
7048    }
7049
7050    #[test]
7051    fn generated_forward_batches_compute_encoders() {
7052        let config = minimal_config();
7053        let model_rs = generate_model_rs(&config).unwrap();
7054
7055        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7056        let forward_start = model_rs
7057            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7058            .unwrap();
7059        let forward_body = &model_rs[forward_start..];
7060        let forward_end = forward_body
7061            .find("\n    pub fn forward_profile")
7062            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7063            .or_else(|| forward_body.find("\n    fn dispatch_"))
7064            .unwrap_or(forward_body.len());
7065        let forward_code = &forward_body[..forward_end];
7066
7067        // Forward should not allocate new buffers
7068        assert!(
7069            !forward_code.contains("device.new_buffer"),
7070            "forward() should not allocate new buffers per call"
7071        );
7072
7073        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7074        // Copy operations use compute copy kernels instead of blit encoders.
7075        let compute_encoder_count = forward_code
7076            .matches("new_compute_command_encoder()")
7077            .count();
7078        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7079
7080        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7081        assert_eq!(
7082            compute_encoder_count, 1,
7083            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7084        );
7085        assert_eq!(
7086            blit_encoder_count, 0,
7087            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7088        );
7089    }
7090
7091    #[test]
7092    fn generated_forward_uses_add_inplace() {
7093        let config = minimal_config();
7094        let model_rs = generate_model_rs(&config).unwrap();
7095
7096        // Should use in-place add (no blit copy-back needed)
7097        assert!(
7098            model_rs.contains("dispatch_add_inplace"),
7099            "forward() should use dispatch_add_inplace for residual connections"
7100        );
7101        assert!(
7102            model_rs.contains("add_inplace_pipeline"),
7103            "MetalModel should have add_inplace_pipeline"
7104        );
7105    }
7106
7107    fn minimal_q8_config() -> ModelConfig {
7108        ModelConfig {
7109            architecture: Architecture::Llama,
7110            hidden_size: 64,
7111            intermediate_size: 128,
7112            num_layers: 2,
7113            num_attention_heads: 4,
7114            num_kv_heads: 4,
7115            head_dim: 16,
7116            vocab_size: 256,
7117            max_seq_len: 512,
7118            rms_norm_eps: 1e-5,
7119            rope_theta: 10000.0,
7120            dtype: DType::Q8_0,
7121            sliding_window_size: None,
7122            qkv_bias: false,
7123        }
7124    }
7125
7126    #[test]
7127    fn generated_shaders_contain_q8_kernel() {
7128        let shaders = generate_metal_shaders(&minimal_config());
7129
7130        assert!(
7131            shaders.contains("kernel void matmul_vec_q8"),
7132            "shaders should contain matmul_vec_q8 kernel"
7133        );
7134        assert!(
7135            shaders.contains("device const uchar* matrix"),
7136            "matmul_vec_q8 should accept raw Q8_0 bytes"
7137        );
7138        assert!(
7139            shaders.contains("packed_short4"),
7140            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7141        );
7142        assert!(
7143            shaders.contains("as_type<char2>"),
7144            "matmul_vec_q8 should bitcast short lanes to char2"
7145        );
7146        assert!(
7147            shaders.contains("device const half*"),
7148            "matmul_vec_q8 should read f16 scale via half pointer"
7149        );
7150    }
7151
7152    #[test]
7153    fn generated_model_uses_fused_qkv_projections() {
7154        let config = minimal_config();
7155        let model_rs = generate_model_rs(&config).unwrap();
7156
7157        // Should have fused QKV weight in layer buffers
7158        assert!(
7159            model_rs.contains("qkv_weight: Buffer"),
7160            "LayerBuffers should have fused qkv_weight field"
7161        );
7162        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7163        assert!(
7164            !model_rs.contains("    q_weight: Buffer"),
7165            "LayerBuffers should not have separate q_weight field"
7166        );
7167        assert!(
7168            !model_rs.contains("    k_weight: Buffer"),
7169            "LayerBuffers should not have separate k_weight field"
7170        );
7171        assert!(
7172            !model_rs.contains("    v_weight: Buffer"),
7173            "LayerBuffers should not have separate v_weight field"
7174        );
7175
7176        // Should have fused gate_up_weight
7177        assert!(
7178            model_rs.contains("gate_up_weight: Buffer"),
7179            "LayerBuffers should have fused gate_up_weight field"
7180        );
7181        // Should NOT have separate gate/up weight fields
7182        assert!(
7183            !model_rs.contains("    gate_weight: Buffer"),
7184            "LayerBuffers should not have separate gate_weight field"
7185        );
7186        assert!(
7187            !model_rs.contains("    up_weight: Buffer"),
7188            "LayerBuffers should not have separate up_weight field"
7189        );
7190
7191        // Should have fused working buffers
7192        assert!(
7193            model_rs.contains("qkv_buf: Buffer"),
7194            "MetalModel should have fused qkv_buf"
7195        );
7196        assert!(
7197            model_rs.contains("gate_up_buf: Buffer"),
7198            "MetalModel should have fused gate_up_buf"
7199        );
7200
7201        // Forward pass should use fused dispatch
7202        assert!(
7203            model_rs.contains("dispatch_silu_mul_fused"),
7204            "forward pass should use dispatch_silu_mul_fused"
7205        );
7206        assert!(
7207            model_rs.contains("dispatch_rope_offset"),
7208            "forward pass should use dispatch_rope_offset for fused QKV"
7209        );
7210        assert!(
7211            model_rs.contains("dispatch_attention_offset"),
7212            "forward pass should use dispatch_attention_offset for fused QKV"
7213        );
7214    }
7215
7216    #[test]
7217    fn q8_model_has_matmul_q8_pipeline() {
7218        let config = minimal_q8_config();
7219        let model_rs = generate_model_rs(&config).unwrap();
7220
7221        assert!(
7222            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7223            "MetalModel should have matmul_q8_pipeline field"
7224        );
7225        assert!(
7226            model_rs.contains("matmul_q8_pipeline,"),
7227            "MetalModel Self should include matmul_q8_pipeline"
7228        );
7229    }
7230
7231    #[test]
7232    fn q8_model_uses_dispatch_matmul_q8() {
7233        let config = minimal_q8_config();
7234        let model_rs = generate_model_rs(&config).unwrap();
7235
7236        assert!(
7237            model_rs.contains("dispatch_matmul_q8"),
7238            "Q8_0 model should use dispatch_matmul_q8 for projections"
7239        );
7240        assert!(
7241            model_rs.contains("fn dispatch_matmul_q8"),
7242            "model.rs should define dispatch_matmul_q8 method"
7243        );
7244    }
7245
7246    #[test]
7247    fn q8_model_loads_raw_bytes_not_dequantized() {
7248        let config = minimal_q8_config();
7249        let model_rs = generate_model_rs(&config).unwrap();
7250
7251        // Should NOT contain dequantization code
7252        assert!(
7253            !model_rs.contains("f16_to_f32"),
7254            "Q8_0 model should not dequantize weights to f32"
7255        );
7256        assert!(
7257            !model_rs.contains("f32_data"),
7258            "Q8_0 model should not create f32 weight data"
7259        );
7260
7261        // Should load raw Q8_0 bytes directly
7262        assert!(
7263            model_rs.contains("total_raw as u64"),
7264            "Q8_0 model should load raw bytes into Metal buffer"
7265        );
7266    }
7267
7268    #[test]
7269    fn q8_model_norms_stay_f32() {
7270        let config = minimal_q8_config();
7271        let model_rs = generate_model_rs(&config).unwrap();
7272
7273        // Norm weights should still use f32 buffers
7274        assert!(
7275            model_rs.contains("let attn_norm = next_f32_buffer"),
7276            "attn_norm should use f32 buffer even for Q8_0 models"
7277        );
7278        assert!(
7279            model_rs.contains("let ffn_norm = next_f32_buffer"),
7280            "ffn_norm should use f32 buffer even for Q8_0 models"
7281        );
7282        assert!(
7283            model_rs.contains("let norm_buf = next_f32_buffer"),
7284            "final norm should use f32 buffer even for Q8_0 models"
7285        );
7286    }
7287
7288    #[test]
7289    fn q8_model_uses_fused_weight_loading() {
7290        let config = minimal_q8_config();
7291        let model_rs = generate_model_rs(&config).unwrap();
7292
7293        // Should use fused Q8 buffer loading for QKV
7294        assert!(
7295            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7296            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7297        );
7298        // Should use fused Q8 buffer loading for gate+up
7299        assert!(
7300            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7301            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7302        );
7303        // Should still use regular q8 buffer for individual weights
7304        assert!(
7305            model_rs.contains("let o_weight = next_q8_buffer"),
7306            "Q8_0 model should use next_q8_buffer for O weight"
7307        );
7308        assert!(
7309            model_rs.contains("let down_weight = next_q8_buffer"),
7310            "Q8_0 model should use next_q8_buffer for down weight"
7311        );
7312    }
7313
7314    #[test]
7315    fn f32_model_does_not_use_q8_dispatch() {
7316        let config = minimal_config();
7317        let model_rs = generate_model_rs(&config).unwrap();
7318
7319        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7320        let forward_start = model_rs
7321            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7322            .unwrap();
7323        let forward_body = &model_rs[forward_start..];
7324        let forward_end = forward_body
7325            .find("\n    fn dispatch_")
7326            .unwrap_or(forward_body.len());
7327        let forward_code = &forward_body[..forward_end];
7328
7329        assert!(
7330            !forward_code.contains("dispatch_matmul_q8"),
7331            "f32 model forward should not use dispatch_matmul_q8"
7332        );
7333    }
7334
7335    #[test]
7336    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7337        let config = minimal_q8_config();
7338        let model_rs = generate_model_rs(&config).unwrap();
7339
7340        assert!(
7341            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7342            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7343        );
7344    }
7345
7346    #[test]
7347    fn generated_model_has_double_buffered_prefill() {
7348        let config = minimal_config();
7349        let model_rs = generate_model_rs(&config).unwrap();
7350
7351        // MetalModel should have prev_cmd field for double-buffered prefill
7352        assert!(
7353            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7354            "MetalModel should have prev_cmd field for double-buffered prefill"
7355        );
7356
7357        // Should have forward_prefill method
7358        assert!(
7359            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7360            "MetalModel should have forward_prefill method"
7361        );
7362
7363        // forward() should drain prev_cmd at the start
7364        assert!(
7365            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7366            "forward() should drain prev_cmd from previous prefill"
7367        );
7368    }
7369
7370    #[test]
7371    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7372        let config = minimal_config();
7373        let main_rs = generate_main_rs("test-model", &config).unwrap();
7374
7375        assert!(
7376            main_rs.contains("forward_prefill"),
7377            "main.rs should use forward_prefill for intermediate prompt tokens"
7378        );
7379        assert!(
7380            main_rs.contains("double-buffered"),
7381            "main.rs should document double-buffered prefill"
7382        );
7383    }
7384
7385    #[test]
7386    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7387        let shaders = generate_metal_shaders(&minimal_config());
7388
7389        assert!(
7390            shaders.contains("packed_short4"),
7391            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7392        );
7393        assert!(
7394            shaders.contains("d0[0]"),
7395            "matmul_vec_q8 should index the wide pointer for row 0"
7396        );
7397        assert!(
7398            shaders.contains("as_type<char2>"),
7399            "matmul_vec_q8 should bitcast short lanes to char2"
7400        );
7401        assert!(
7402            shaders.contains("dot("),
7403            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7404        );
7405    }
7406
7407    // ── Q4_0 tests ──────────────────────────────────────────────────────
7408
7409    fn minimal_q4_config() -> ModelConfig {
7410        ModelConfig {
7411            architecture: Architecture::Llama,
7412            hidden_size: 64,
7413            intermediate_size: 128,
7414            num_layers: 2,
7415            num_attention_heads: 4,
7416            num_kv_heads: 4,
7417            head_dim: 16,
7418            vocab_size: 256,
7419            max_seq_len: 512,
7420            rms_norm_eps: 1e-5,
7421            rope_theta: 10000.0,
7422            dtype: DType::Q4_0,
7423            sliding_window_size: None,
7424            qkv_bias: false,
7425        }
7426    }
7427
7428    #[test]
7429    fn generated_shaders_contain_q4_kernel() {
7430        let shaders = generate_metal_shaders(&minimal_config());
7431
7432        assert!(
7433            shaders.contains("kernel void matmul_vec_q4"),
7434            "shaders should contain matmul_vec_q4 kernel"
7435        );
7436        assert!(
7437            shaders.contains("Q4_ROWS_PER_TG"),
7438            "shaders should define Q4_ROWS_PER_TG constant"
7439        );
7440        assert!(
7441            shaders.contains("Q4_ROWS_PER_SG"),
7442            "shaders should define Q4_ROWS_PER_SG constant"
7443        );
7444    }
7445
7446    #[test]
7447    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7448        let shaders = generate_metal_shaders(&minimal_config());
7449
7450        // Q4_0 kernel should use uchar4 for packed byte loads
7451        assert!(
7452            shaders.contains("uchar4"),
7453            "matmul_vec_q4 should use uchar4 for packed byte loads"
7454        );
7455        // Should unpack nibbles with &0xF and >>4
7456        assert!(
7457            shaders.contains("&0xF"),
7458            "matmul_vec_q4 should extract low nibble with &0xF"
7459        );
7460        assert!(
7461            shaders.contains(">>4"),
7462            "matmul_vec_q4 should extract high nibble with >>4"
7463        );
7464        // Should subtract 8 to convert unsigned to signed
7465        assert!(
7466            shaders.contains("-8)"),
7467            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7468        );
7469        // Should use 18-byte block size
7470        assert!(
7471            shaders.contains("blk * 18"),
7472            "matmul_vec_q4 should use 18-byte block stride"
7473        );
7474    }
7475
7476    #[test]
7477    fn q4_model_has_matmul_q4_pipeline() {
7478        let config = minimal_q4_config();
7479        let model_rs = generate_model_rs(&config).unwrap();
7480
7481        assert!(
7482            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7483            "MetalModel should have matmul_q4_pipeline field"
7484        );
7485        assert!(
7486            model_rs.contains("matmul_q4_pipeline,"),
7487            "MetalModel Self should include matmul_q4_pipeline"
7488        );
7489    }
7490
7491    #[test]
7492    fn q4_model_uses_dispatch_matmul_q4() {
7493        let config = minimal_q4_config();
7494        let model_rs = generate_model_rs(&config).unwrap();
7495
7496        assert!(
7497            model_rs.contains("dispatch_matmul_q4"),
7498            "Q4_0 model should use dispatch_matmul_q4 for projections"
7499        );
7500        assert!(
7501            model_rs.contains("fn dispatch_matmul_q4"),
7502            "model.rs should define dispatch_matmul_q4 method"
7503        );
7504    }
7505
7506    #[test]
7507    fn q4_model_loads_raw_bytes_not_dequantized() {
7508        let config = minimal_q4_config();
7509        let model_rs = generate_model_rs(&config).unwrap();
7510
7511        // Should NOT contain dequantization code
7512        assert!(
7513            !model_rs.contains("f16_to_f32"),
7514            "Q4_0 model should not dequantize weights to f32"
7515        );
7516
7517        // Should load raw Q4_0 bytes directly
7518        assert!(
7519            model_rs.contains("total_raw as u64"),
7520            "Q4_0 model should load raw bytes into Metal buffer"
7521        );
7522    }
7523
7524    #[test]
7525    fn q4_model_norms_stay_f32() {
7526        let config = minimal_q4_config();
7527        let model_rs = generate_model_rs(&config).unwrap();
7528
7529        assert!(
7530            model_rs.contains("let attn_norm = next_f32_buffer"),
7531            "attn_norm should use f32 buffer even for Q4_0 models"
7532        );
7533        assert!(
7534            model_rs.contains("let ffn_norm = next_f32_buffer"),
7535            "ffn_norm should use f32 buffer even for Q4_0 models"
7536        );
7537        assert!(
7538            model_rs.contains("let norm_buf = next_f32_buffer"),
7539            "final norm should use f32 buffer even for Q4_0 models"
7540        );
7541    }
7542
7543    #[test]
7544    fn q4_model_uses_fused_weight_loading() {
7545        let config = minimal_q4_config();
7546        let model_rs = generate_model_rs(&config).unwrap();
7547
7548        assert!(
7549            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7550            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7551        );
7552        assert!(
7553            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7554            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7555        );
7556        assert!(
7557            model_rs.contains("let o_weight = next_q4_buffer"),
7558            "Q4_0 model should use next_q4_buffer for O weight"
7559        );
7560        assert!(
7561            model_rs.contains("let down_weight = next_q4_buffer"),
7562            "Q4_0 model should use next_q4_buffer for down weight"
7563        );
7564    }
7565
7566    #[test]
7567    fn attention_flash_batch_kernel_exists() {
7568        // The flash kernel is still wired into the library (pipeline, kernel
7569        // source).  Dispatch is currently routed to the legacy path pending a
7570        // fix for a numerical issue discovered after the prompt-chunking fix.
7571        let config = minimal_config();
7572        let model_rs = generate_model_rs(&config).unwrap();
7573        let shaders = generate_metal_shaders(&config);
7574
7575        assert!(
7576            shaders.contains("kernel void attention_flash_batch"),
7577            "shaders.metal must still contain the attention_flash_batch kernel"
7578        );
7579        assert!(
7580            shaders.contains("FLASH_K_TILE"),
7581            "flash kernel must tile K/V with a FLASH_K_TILE constant"
7582        );
7583        assert!(
7584            model_rs.contains("attention_flash_batch_pipeline"),
7585            "MetalModel must register the flash attention pipeline"
7586        );
7587    }
7588
7589    #[test]
7590    fn forward_prefill_batch_chunks_by_max_batch_size() {
7591        // Regression: prior to v0.6.4 forward_prefill_batch truncated prompts
7592        // longer than MAX_BATCH_SIZE (512) tokens via `.min(MAX_BATCH_SIZE)`,
7593        // silently dropping the middle of long prompts.  Must now loop over
7594        // MAX_BATCH_SIZE-sized chunks and carry KV-cache state across them.
7595        let config = minimal_config();
7596        let model_rs = generate_model_rs(&config).unwrap();
7597        assert!(
7598            model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
7599            "forward_prefill_batch must chunk long prompts"
7600        );
7601        assert!(
7602            !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
7603            "the old truncation path must be gone"
7604        );
7605    }
7606
7607    #[test]
7608    fn qwen2_qkv_bias_wired_through_metal_codegen() {
7609        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
7610        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
7611        // loader, pipeline, and dispatch call in the expected places.
7612        let config = ModelConfig {
7613            architecture: Architecture::Qwen2,
7614            qkv_bias: true,
7615            ..minimal_config()
7616        };
7617        let model_rs = generate_model_rs(&config).unwrap();
7618
7619        assert!(
7620            model_rs.contains("qkv_bias: Buffer"),
7621            "Qwen2 LayerBuffers must declare qkv_bias field"
7622        );
7623        assert!(
7624            model_rs.contains("let qkv_bias = next_f32_buffer"),
7625            "Qwen2 layer init must load the bias from the weight blob"
7626        );
7627        assert!(
7628            model_rs.contains("add_bias_batch_pipeline"),
7629            "Qwen2 model struct must include the add_bias_batch_pipeline"
7630        );
7631        assert!(
7632            model_rs.contains("fn dispatch_add_bias_batch"),
7633            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
7634        );
7635        assert!(
7636            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
7637            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
7638        );
7639        assert!(
7640            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
7641            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
7642        );
7643
7644        // The add_bias_batch MSL kernel must be in the shader source.
7645        let shaders = generate_metal_shaders(&config);
7646        assert!(
7647            shaders.contains("kernel void add_bias_batch"),
7648            "shaders.metal must contain the add_bias_batch kernel"
7649        );
7650    }
7651
7652    #[test]
7653    fn llama_does_not_emit_qkv_bias_machinery() {
7654        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
7655        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
7656        let config = minimal_config();
7657        assert!(!config.qkv_bias);
7658        let model_rs = generate_model_rs(&config).unwrap();
7659        assert!(
7660            !model_rs.contains("qkv_bias: Buffer"),
7661            "Llama must not have qkv_bias field"
7662        );
7663        assert!(
7664            !model_rs.contains("add_bias_batch_pipeline"),
7665            "Llama must not pull in add_bias_batch_pipeline"
7666        );
7667        assert!(
7668            !model_rs.contains("dispatch_add_bias_batch"),
7669            "Llama must not call dispatch_add_bias_batch"
7670        );
7671    }
7672
7673    #[test]
7674    fn q4_dispatch_helper_takes_compute_encoder_ref() {
7675        let config = minimal_q4_config();
7676        let model_rs = generate_model_rs(&config).unwrap();
7677
7678        assert!(
7679            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
7680            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
7681        );
7682    }
7683
7684    #[test]
7685    fn f32_model_does_not_use_q4_dispatch() {
7686        let config = minimal_config();
7687        let model_rs = generate_model_rs(&config).unwrap();
7688
7689        let forward_start = model_rs
7690            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7691            .unwrap();
7692        let forward_body = &model_rs[forward_start..];
7693        let forward_end = forward_body
7694            .find("\n    fn dispatch_")
7695            .unwrap_or(forward_body.len());
7696        let forward_code = &forward_body[..forward_end];
7697
7698        assert!(
7699            !forward_code.contains("dispatch_matmul_q4"),
7700            "f32 model forward should not use dispatch_matmul_q4"
7701        );
7702    }
7703
7704    #[test]
7705    fn q4_model_lm_head_uses_q4_buffer() {
7706        let config = minimal_q4_config();
7707        let model_rs = generate_model_rs(&config).unwrap();
7708
7709        assert!(
7710            model_rs.contains("let lm_head_buf = next_q4_buffer"),
7711            "Q4_0 model should use next_q4_buffer for lm_head"
7712        );
7713    }
7714
7715    #[test]
7716    fn vec_tile_size_matches_model_dimensions() {
7717        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
7718        let small = minimal_config();
7719        let shaders_small = generate_metal_shaders(&small);
7720        assert!(
7721            shaders_small.contains("vec_tile[128]"),
7722            "vec_tile should be sized to max(hidden, intermediate) = 128"
7723        );
7724
7725        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
7726        let mut large = minimal_config();
7727        large.hidden_size = 2048;
7728        large.intermediate_size = 8192;
7729        let shaders_large = generate_metal_shaders(&large);
7730        assert!(
7731            shaders_large.contains("vec_tile[8192]"),
7732            "vec_tile should be 8192 for models with intermediate=8192"
7733        );
7734        assert!(
7735            !shaders_large.contains("vec_tile[4096]"),
7736            "vec_tile should NOT be hardcoded to 4096"
7737        );
7738    }
7739
7740    #[test]
7741    fn generated_cargo_toml_has_server_deps() {
7742        let toml = generate_cargo_toml("my-model");
7743        assert!(
7744            toml.contains("tiny_http"),
7745            "Cargo.toml should depend on tiny_http"
7746        );
7747        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
7748        assert!(
7749            toml.contains("serde_json"),
7750            "Cargo.toml should depend on serde_json"
7751        );
7752    }
7753
7754    #[test]
7755    fn generated_main_rs_has_serve_mode() {
7756        let config = minimal_config();
7757        let main_rs = generate_main_rs("test-model", &config).unwrap();
7758
7759        assert!(
7760            main_rs.contains("--serve"),
7761            "main.rs should parse --serve flag"
7762        );
7763        assert!(
7764            main_rs.contains("--port"),
7765            "main.rs should parse --port flag"
7766        );
7767        assert!(
7768            main_rs.contains("fn serve("),
7769            "main.rs should define serve function"
7770        );
7771        assert!(
7772            main_rs.contains("tiny_http::Server::http"),
7773            "main.rs should create tiny_http server"
7774        );
7775    }
7776
7777    #[test]
7778    fn generated_main_rs_has_chat_completions_endpoint() {
7779        let config = minimal_config();
7780        let main_rs = generate_main_rs("test-model", &config).unwrap();
7781
7782        assert!(
7783            main_rs.contains("/v1/chat/completions"),
7784            "main.rs should handle /v1/chat/completions endpoint"
7785        );
7786        assert!(
7787            main_rs.contains("/v1/models"),
7788            "main.rs should handle /v1/models endpoint"
7789        );
7790        assert!(
7791            main_rs.contains("/health"),
7792            "main.rs should handle /health endpoint"
7793        );
7794    }
7795
7796    #[test]
7797    fn generated_main_rs_has_sse_streaming() {
7798        let config = minimal_config();
7799        let main_rs = generate_main_rs("test-model", &config).unwrap();
7800
7801        assert!(
7802            main_rs.contains("text/event-stream"),
7803            "main.rs should set SSE content type for streaming"
7804        );
7805        assert!(
7806            main_rs.contains("chat.completion.chunk"),
7807            "main.rs should emit SSE chunks"
7808        );
7809        assert!(
7810            main_rs.contains("[DONE]"),
7811            "main.rs should emit [DONE] sentinel"
7812        );
7813    }
7814
7815    #[test]
7816    fn generated_main_rs_has_chat_message_formatting() {
7817        let config = minimal_config();
7818        let main_rs = generate_main_rs("test-model", &config).unwrap();
7819
7820        assert!(
7821            main_rs.contains("fn format_chat_messages"),
7822            "main.rs should define format_chat_messages function"
7823        );
7824        assert!(
7825            main_rs.contains("<|im_start|>"),
7826            "main.rs should use ChatML format"
7827        );
7828        assert!(
7829            main_rs.contains("<|im_end|>"),
7830            "main.rs should use ChatML format"
7831        );
7832    }
7833
7834    #[test]
7835    fn generated_main_rs_has_request_types() {
7836        let config = minimal_config();
7837        let main_rs = generate_main_rs("test-model", &config).unwrap();
7838
7839        assert!(
7840            main_rs.contains("struct ChatRequest"),
7841            "main.rs should define ChatRequest struct"
7842        );
7843        assert!(
7844            main_rs.contains("struct ChatMessage"),
7845            "main.rs should define ChatMessage struct"
7846        );
7847        assert!(
7848            main_rs.contains("Deserialize"),
7849            "main.rs should derive Deserialize for request types"
7850        );
7851    }
7852
7853    #[test]
7854    fn generated_model_has_reset_method() {
7855        let config = minimal_config();
7856        let model_rs = generate_model_rs(&config).unwrap();
7857
7858        assert!(
7859            model_rs.contains("pub fn reset(&mut self)"),
7860            "model.rs should have a reset() method for multi-request serving"
7861        );
7862        assert!(
7863            model_rs.contains("self.pos = 0"),
7864            "reset() should reset position to 0"
7865        );
7866    }
7867
7868    #[test]
7869    fn generated_main_rs_cli_mode_still_works() {
7870        let config = minimal_config();
7871        let main_rs = generate_main_rs("test-model", &config).unwrap();
7872
7873        // CLI mode should still be functional
7874        assert!(
7875            main_rs.contains("fn cli_mode("),
7876            "main.rs should define cli_mode function"
7877        );
7878        assert!(
7879            main_rs.contains("model.forward"),
7880            "main.rs should call model.forward"
7881        );
7882        assert!(
7883            main_rs.contains("model.forward_prefill"),
7884            "main.rs should call model.forward_prefill"
7885        );
7886    }
7887
7888    // ── Batched prefill tests ──────────────────────────────────────────
7889
7890    #[test]
7891    fn generated_shaders_contain_batch_kernels() {
7892        let shaders = generate_metal_shaders(&minimal_config());
7893
7894        assert!(
7895            shaders.contains("kernel void matmul_vec_batch"),
7896            "shaders should contain matmul_vec_batch kernel"
7897        );
7898        assert!(
7899            shaders.contains("kernel void matmul_vec_q8_batch"),
7900            "shaders should contain matmul_vec_q8_batch kernel"
7901        );
7902        assert!(
7903            shaders.contains("kernel void matmul_q8_gemm_batch"),
7904            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
7905        );
7906        assert!(
7907            shaders.contains("kernel void matmul_vec_q4_batch"),
7908            "shaders should contain matmul_vec_q4_batch kernel"
7909        );
7910        assert!(
7911            shaders.contains("kernel void rms_norm_batch"),
7912            "shaders should contain rms_norm_batch kernel"
7913        );
7914        assert!(
7915            shaders.contains("kernel void silu_mul_fused_batch"),
7916            "shaders should contain silu_mul_fused_batch kernel"
7917        );
7918        assert!(
7919            shaders.contains("kernel void add_inplace_batch"),
7920            "shaders should contain add_inplace_batch kernel"
7921        );
7922        assert!(
7923            shaders.contains("kernel void copy_embedding_batch"),
7924            "shaders should contain copy_embedding_batch kernel"
7925        );
7926    }
7927
7928    #[test]
7929    fn generated_model_has_batch_pipelines() {
7930        let config = minimal_config();
7931        let model_rs = generate_model_rs(&config).unwrap();
7932
7933        for pipeline in &[
7934            "matmul_batch_pipeline",
7935            "matmul_q8_batch_pipeline",
7936            "matmul_q8_gemm_batch_pipeline",
7937            "matmul_q4_batch_pipeline",
7938            "rms_norm_batch_pipeline",
7939            "rope_batch_pipeline",
7940            "silu_mul_fused_batch_pipeline",
7941            "add_inplace_batch_pipeline",
7942            "copy_embedding_batch_pipeline",
7943            "attention_batch_pipeline",
7944            "copy_kv_batch_pipeline",
7945        ] {
7946            assert!(
7947                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
7948                "MetalModel should have {pipeline} field"
7949            );
7950        }
7951    }
7952
7953    #[test]
7954    fn generated_model_has_batch_buffers() {
7955        let config = minimal_config();
7956        let model_rs = generate_model_rs(&config).unwrap();
7957
7958        for buf in &[
7959            "batch_hidden_buf",
7960            "batch_residual_buf",
7961            "batch_qkv_buf",
7962            "batch_attn_out_buf",
7963            "batch_attn_proj_buf",
7964            "batch_gate_up_buf",
7965            "batch_ffn_hidden_buf",
7966            "batch_ffn_out_buf",
7967            "batch_tokens_buf",
7968            "batch_positions_buf",
7969        ] {
7970            assert!(
7971                model_rs.contains(&format!("{buf}: Buffer")),
7972                "MetalModel should have {buf} field"
7973            );
7974        }
7975    }
7976
7977    #[test]
7978    fn generated_model_has_forward_prefill_batch() {
7979        let config = minimal_config();
7980        let model_rs = generate_model_rs(&config).unwrap();
7981
7982        assert!(
7983            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
7984            "MetalModel should have forward_prefill_batch method"
7985        );
7986
7987        // forward_prefill should delegate to forward_prefill_batch
7988        assert!(
7989            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
7990            "forward_prefill should delegate to forward_prefill_batch"
7991        );
7992    }
7993
7994    #[test]
7995    fn generated_model_has_max_batch_size_constant() {
7996        let config = minimal_config();
7997        let model_rs = generate_model_rs(&config).unwrap();
7998
7999        assert!(
8000            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8001            "model.rs should define MAX_BATCH_SIZE constant"
8002        );
8003    }
8004
8005    #[test]
8006    fn forward_prefill_batch_uses_batch_dispatch() {
8007        let config = minimal_config();
8008        let model_rs = generate_model_rs(&config).unwrap();
8009
8010        let batch_start = model_rs
8011            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8012            .unwrap();
8013        let batch_body = &model_rs[batch_start..];
8014        let batch_end = batch_body
8015            .find("\n    pub fn reset")
8016            .unwrap_or(batch_body.len());
8017        let batch_code = &batch_body[..batch_end];
8018
8019        // Should use batched dispatch methods
8020        assert!(
8021            batch_code.contains("dispatch_rms_norm_batch"),
8022            "forward_prefill_batch should use dispatch_rms_norm_batch"
8023        );
8024        assert!(
8025            batch_code.contains("dispatch_copy_embedding_batch"),
8026            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8027        );
8028        assert!(
8029            batch_code.contains("dispatch_silu_mul_fused_batch"),
8030            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8031        );
8032        // Should use batched causal attention dispatch
8033        assert!(
8034            batch_code.contains("dispatch_attention_batch"),
8035            "forward_prefill_batch should use dispatch_attention_batch"
8036        );
8037        // Should use fused KV cache copy (both K and V in one dispatch)
8038        assert!(
8039            batch_code.contains("dispatch_copy_kv_both_batch"),
8040            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8041        );
8042        // Should use fused RoPE Q+K dispatch
8043        assert!(
8044            batch_code.contains("dispatch_rope_qk_batch"),
8045            "forward_prefill_batch should use dispatch_rope_qk_batch"
8046        );
8047    }
8048
8049    #[test]
8050    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8051        let config = minimal_q8_config();
8052        let model_rs = generate_model_rs(&config).unwrap();
8053
8054        let batch_start = model_rs
8055            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8056            .unwrap();
8057        let batch_body = &model_rs[batch_start..];
8058        let batch_end = batch_body
8059            .find("\n    pub fn reset")
8060            .unwrap_or(batch_body.len());
8061        let batch_code = &batch_body[..batch_end];
8062
8063        assert!(
8064            batch_code.contains("dispatch_matmul_q8_batch"),
8065            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8066        );
8067    }
8068
8069    #[test]
8070    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8071        let config = minimal_q4_config();
8072        let model_rs = generate_model_rs(&config).unwrap();
8073
8074        let batch_start = model_rs
8075            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8076            .unwrap();
8077        let batch_body = &model_rs[batch_start..];
8078        let batch_end = batch_body
8079            .find("\n    pub fn reset")
8080            .unwrap_or(batch_body.len());
8081        let batch_code = &batch_body[..batch_end];
8082
8083        assert!(
8084            batch_code.contains("dispatch_matmul_q4_batch"),
8085            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8086        );
8087    }
8088
8089    #[test]
8090    fn generated_main_rs_uses_batched_prefill() {
8091        let config = minimal_config();
8092        let main_rs = generate_main_rs("test-model", &config).unwrap();
8093
8094        assert!(
8095            main_rs.contains("forward_prefill_batch"),
8096            "main.rs should use forward_prefill_batch for prompt tokens"
8097        );
8098    }
8099
8100    #[test]
8101    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8102        let config = minimal_config();
8103        let model_rs = generate_model_rs(&config).unwrap();
8104
8105        let batch_start = model_rs
8106            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8107            .unwrap();
8108        let batch_body = &model_rs[batch_start..];
8109        let batch_end = batch_body
8110            .find("\n    pub fn reset")
8111            .unwrap_or(batch_body.len());
8112        let batch_code = &batch_body[..batch_end];
8113
8114        assert!(
8115            batch_code.contains("dispatch_matmul_batch"),
8116            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8117        );
8118        // Should NOT use Q8 or Q4 batch dispatch
8119        assert!(
8120            !batch_code.contains("dispatch_matmul_q8_batch"),
8121            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8122        );
8123        assert!(
8124            !batch_code.contains("dispatch_matmul_q4_batch"),
8125            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8126        );
8127    }
8128
8129    #[test]
8130    fn forward_uses_cpu_embedding_lookup() {
8131        let config = minimal_config();
8132        let model_rs = generate_model_rs(&config).unwrap();
8133
8134        // Find just the forward() body (not forward_profile)
8135        let forward_start = model_rs
8136            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8137            .unwrap();
8138        let forward_body = &model_rs[forward_start..];
8139        let forward_end = forward_body
8140            .find("\n    pub fn forward_profile")
8141            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8142            .unwrap_or(forward_body.len());
8143        let forward_code = &forward_body[..forward_end];
8144
8145        // forward() should use CPU memcpy for embedding lookup (unified memory)
8146        assert!(
8147            forward_code.contains("embed_buf.contents()"),
8148            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8149        );
8150        assert!(
8151            forward_code.contains("copy_nonoverlapping"),
8152            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8153        );
8154        // forward() should NOT use GPU dispatch for embedding
8155        assert!(
8156            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8157            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8158        );
8159    }
8160
8161    #[test]
8162    fn forward_profile_method_exists() {
8163        let config = minimal_config();
8164        let model_rs = generate_model_rs(&config).unwrap();
8165
8166        assert!(
8167            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8168            "MetalModel should have forward_profile() method"
8169        );
8170        // Profile method should print timing information
8171        assert!(
8172            model_rs.contains("[profile]"),
8173            "forward_profile() should print timing with [profile] prefix"
8174        );
8175        assert!(
8176            model_rs.contains("d_embed"),
8177            "forward_profile() should measure embedding time"
8178        );
8179        assert!(
8180            model_rs.contains("d_layers"),
8181            "forward_profile() should measure layer time"
8182        );
8183        assert!(
8184            model_rs.contains("d_logits"),
8185            "forward_profile() should measure logits time"
8186        );
8187    }
8188
8189    #[test]
8190    fn generated_cli_has_profile_flag() {
8191        let config = minimal_config();
8192        let main_rs = generate_main_rs("test-model", &config).unwrap();
8193
8194        assert!(
8195            main_rs.contains("--profile"),
8196            "CLI should support --profile flag"
8197        );
8198        assert!(
8199            main_rs.contains("forward_profile"),
8200            "CLI should call forward_profile when --profile is set"
8201        );
8202    }
8203
8204    #[test]
8205    fn generated_cli_has_thermal_yield() {
8206        let config = minimal_config();
8207        let main_rs = generate_main_rs("test-model", &config).unwrap();
8208
8209        assert!(
8210            main_rs.contains("yield_now()"),
8211            "CLI generation loop should include thread::yield_now() for thermal management"
8212        );
8213    }
8214
8215    // ── Real-world validation tests ──────────────────────────────────────
8216
8217    #[test]
8218    fn generated_forward_handles_single_token_prompt() {
8219        // With a single token (the first prompt token), forward() should work
8220        // at pos=0 where seq_len=1. The attention kernel must handle the case
8221        // where there is only one KV entry (no prefill context).
8222        let config = minimal_config();
8223        let model_rs = generate_model_rs(&config).unwrap();
8224
8225        // The forward function should accept any u32 token_id (no minimum pos guard)
8226        let forward_start = model_rs
8227            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8228            .expect("forward() must exist");
8229        let forward_body = &model_rs[forward_start..forward_start + 400];
8230
8231        // Should NOT require pos > 0 or seq_len > 1
8232        assert!(
8233            !forward_body.contains("assert!(self.pos > 0"),
8234            "forward() must accept pos=0 (first token with no prefill)"
8235        );
8236
8237        // The attention kernel should handle seq_len=1 via the pos field
8238        assert!(
8239            model_rs.contains("self.pos"),
8240            "forward() should use self.pos to track sequence position"
8241        );
8242    }
8243
8244    #[test]
8245    fn generated_reset_clears_kv_cache_position() {
8246        // After reset(), the model should be in a clean state. The pos field
8247        // must be 0 so new generation starts from scratch.
8248        let config = minimal_config();
8249        let model_rs = generate_model_rs(&config).unwrap();
8250
8251        let reset_start = model_rs
8252            .find("pub fn reset(&mut self)")
8253            .expect("reset() must exist");
8254        let reset_body = &model_rs[reset_start..reset_start + 200];
8255
8256        // Reset must zero the position counter
8257        assert!(
8258            reset_body.contains("self.pos = 0"),
8259            "reset() must set self.pos = 0"
8260        );
8261
8262        // Verify reset clears prev_cmd (double-buffering state)
8263        assert!(
8264            reset_body.contains("self.prev_cmd = None"),
8265            "reset() should clear prev_cmd for clean command buffer state"
8266        );
8267    }
8268
8269    #[test]
8270    fn generated_serve_handles_empty_messages_gracefully() {
8271        // The serve endpoint should not crash when receiving an empty messages array.
8272        // The format_chat_messages function should handle this gracefully.
8273        let config = minimal_config();
8274        let main_rs = generate_main_rs("test-model", &config).unwrap();
8275
8276        // The format_chat_messages function should exist and handle empty input
8277        let format_fn_start = main_rs
8278            .find("fn format_chat_messages")
8279            .expect("format_chat_messages must exist");
8280        let format_fn_body =
8281            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8282
8283        // It should iterate over messages (an empty slice produces an empty loop)
8284        assert!(
8285            format_fn_body.contains("for msg in messages"),
8286            "format_chat_messages should iterate over the messages slice"
8287        );
8288        // It should always append the assistant prompt suffix
8289        assert!(
8290            format_fn_body.contains("<|im_start|>assistant"),
8291            "format_chat_messages should always append assistant prompt header"
8292        );
8293
8294        // The serve function should call model.reset() before each request
8295        let serve_fn_start = main_rs
8296            .find("fn serve(")
8297            .expect("serve function must exist");
8298        let serve_fn_body = &main_rs[serve_fn_start..];
8299        assert!(
8300            serve_fn_body.contains("model.reset()"),
8301            "serve function should reset model between requests"
8302        );
8303    }
8304
8305    #[test]
8306    fn generated_model_forward_increments_pos() {
8307        // Each forward() call must increment self.pos so the next token
8308        // uses the correct RoPE position and KV cache offset.
8309        let config = minimal_config();
8310        let model_rs = generate_model_rs(&config).unwrap();
8311
8312        let forward_start = model_rs
8313            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8314            .unwrap();
8315        let forward_body = &model_rs[forward_start..];
8316        let forward_end = forward_body
8317            .find("\n    pub fn forward_profile")
8318            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8319            .or_else(|| forward_body.find("\n    fn dispatch_"))
8320            .unwrap_or(forward_body.len());
8321        let forward_code = &forward_body[..forward_end];
8322
8323        assert!(
8324            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8325            "forward() must increment self.pos after processing a token"
8326        );
8327    }
8328}