Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn a free-form model name into a valid Cargo package / `[[bin]]` name.
///
/// ASCII letters and digits are kept, lowercased. Every other character
/// becomes `-`, including `-` itself, `_`, `.`, whitespace and any non-ASCII
/// character. Leading and trailing dashes are then trimmed.
///
/// Cargo rejects empty names and names whose first character is a digit, so:
/// - a name containing no ASCII alphanumerics yields `"model"`;
/// - a name that would start with a digit is prefixed with `"model-"`.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    // Restrict to ASCII. `char::is_alphanumeric` would also accept characters
    // such as `²` that Cargo refuses in package names.
    let mapped: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_lowercase()
            } else {
                '-'
            }
        })
        .collect();
    let trimmed = mapped.trim_matches('-');

    match trimmed.chars().next() {
        None => FALLBACK.to_string(),
        Some(first) if first.is_ascii_digit() => format!("{FALLBACK}-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    // The attention-scores shared array must be at least effective_seq_len
122    // elements; anything smaller silently overflows for long prompts.  For
123    // small-context models (135M at 2K), a smaller array saves threadgroup
124    // memory and improves occupancy — so we size it precisely.
125    let attn_scores_size = config.max_seq_len.min(4096);
126    r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154    device const float* matrix [[buffer(0)]],
155    device const float* vector [[buffer(1)]],
156    device float* output       [[buffer(2)]],
157    constant uint& rows        [[buffer(3)]],
158    constant uint& cols        [[buffer(4)]],
159    uint tgid [[threadgroup_position_in_grid]],
160    uint tid [[thread_index_in_threadgroup]],
161    uint simd_lane [[thread_index_in_simdgroup]],
162    uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164    // Cooperatively load vector into threadgroup shared memory
165    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166    for (uint i = tid; i < cols; i += 256) {
167        vec_tile[i] = vector[i];
168    }
169    threadgroup_barrier(mem_flags::mem_threadgroup);
170
171    // Each simdgroup handles 8 consecutive rows
172    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173    if (row_base >= rows) return;
174
175    uint base0 = row_base * cols;
176    uint base1 = (row_base + 1) * cols;
177    uint base2 = (row_base + 2) * cols;
178    uint base3 = (row_base + 3) * cols;
179    uint base4 = (row_base + 4) * cols;
180    uint base5 = (row_base + 5) * cols;
181    uint base6 = (row_base + 6) * cols;
182    uint base7 = (row_base + 7) * cols;
183
184    // float4 vectorized accumulation across 8 rows
185    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
186    float4 sum4_0 = float4(0.0f);
187    float4 sum4_1 = float4(0.0f);
188    float4 sum4_2 = float4(0.0f);
189    float4 sum4_3 = float4(0.0f);
190    float4 sum4_4 = float4(0.0f);
191    float4 sum4_5 = float4(0.0f);
192    float4 sum4_6 = float4(0.0f);
193    float4 sum4_7 = float4(0.0f);
194
195    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205    }
206
207    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216    // Handle remaining elements (cols not divisible by 128)
217    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218        float vv = vec_tile[j];
219        sum0 += matrix[base0 + j] * vv;
220        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227    }
228
229    // Simdgroup hardware warp-level reduction
230    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235    // Only first lane writes the results
236    if (simd_lane == 0) {
237        if (row_base     < rows) output[row_base]     = sum0;
238        if (row_base + 1 < rows) output[row_base + 1] = sum1;
239        if (row_base + 2 < rows) output[row_base + 2] = sum2;
240        if (row_base + 3 < rows) output[row_base + 3] = sum3;
241        if (row_base + 4 < rows) output[row_base + 4] = sum4;
242        if (row_base + 5 < rows) output[row_base + 5] = sum5;
243        if (row_base + 6 < rows) output[row_base + 6] = sum6;
244        if (row_base + 7 < rows) output[row_base + 7] = sum7;
245    }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253    device const float* input   [[buffer(0)]],
254    device const float* weight  [[buffer(1)]],
255    device float* output        [[buffer(2)]],
256    constant uint& n            [[buffer(3)]],
257    constant float& eps         [[buffer(4)]],
258    uint tid [[thread_index_in_threadgroup]])
259{
260    // Each thread accumulates partial sum-of-squares
261    float sum_sq = 0.0f;
262    for (uint i = tid; i < n; i += 256) {
263        float v = input[i];
264        sum_sq += v * v;
265    }
266
267    // Simdgroup-level reduction (hardware warp sum)
268    sum_sq = simd_sum(sum_sq);
269
270    // Cross-simdgroup reduction via shared memory
271    threadgroup float shared[8];
272    uint simd_id = tid / 32;
273    uint simd_lane = tid % 32;
274    if (simd_lane == 0) {
275        shared[simd_id] = sum_sq;
276    }
277    threadgroup_barrier(mem_flags::mem_threadgroup);
278
279    // First thread computes final inverse RMS
280    if (tid == 0) {
281        float total = 0.0f;
282        for (uint i = 0; i < 8; i++) {
283            total += shared[i];
284        }
285        shared[0] = fast::rsqrt(total / float(n) + eps);
286    }
287    threadgroup_barrier(mem_flags::mem_threadgroup);
288
289    float inv_rms = shared[0];
290
291    // Normalize
292    for (uint i = tid; i < n; i += 256) {
293        output[i] = input[i] * inv_rms * weight[i];
294    }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301    device float* data        [[buffer(0)]],
302    constant uint& num_heads  [[buffer(1)]],
303    constant uint& head_dim   [[buffer(2)]],
304    constant uint& pos        [[buffer(3)]],
305    constant float& theta     [[buffer(4)]],
306    uint id [[thread_position_in_grid]])
307{
308    uint half_dim = head_dim / 2;
309    uint total_pairs = num_heads * half_dim;
310    if (id >= total_pairs) return;
311
312    uint h = id / half_dim;
313    uint i = id % half_dim;
314    uint off = h * head_dim;
315
316    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317    float angle = float(pos) * freq;
318    float c = cos(angle);
319    float s = sin(angle);
320
321    float x0 = data[off + 2 * i];
322    float x1 = data[off + 2 * i + 1];
323    data[off + 2 * i]     = x0 * c - x1 * s;
324    data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331    device float* data       [[buffer(0)]],
332    constant uint& n         [[buffer(1)]],
333    uint tid [[thread_index_in_threadgroup]],
334    uint tg_size [[threads_per_threadgroup]])
335{
336    threadgroup float shared_val[256];
337
338    // Pass 1: find max
339    float local_max = -INFINITY;
340    for (uint i = tid; i < n; i += tg_size) {
341        local_max = max(local_max, data[i]);
342    }
343    shared_val[tid] = local_max;
344    threadgroup_barrier(mem_flags::mem_threadgroup);
345
346    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347        if (tid < stride) {
348            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349        }
350        threadgroup_barrier(mem_flags::mem_threadgroup);
351    }
352    float max_val = shared_val[0];
353    threadgroup_barrier(mem_flags::mem_threadgroup);
354
355    // Pass 2: exp and sum
356    float local_sum = 0.0f;
357    for (uint i = tid; i < n; i += tg_size) {
358        float e = fast::exp(data[i] - max_val);
359        data[i] = e;
360        local_sum += e;
361    }
362    shared_val[tid] = local_sum;
363    threadgroup_barrier(mem_flags::mem_threadgroup);
364
365    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366        if (tid < stride) {
367            shared_val[tid] += shared_val[tid + stride];
368        }
369        threadgroup_barrier(mem_flags::mem_threadgroup);
370    }
371    float inv_sum = 1.0f / shared_val[0];
372    threadgroup_barrier(mem_flags::mem_threadgroup);
373
374    // Pass 3: normalize
375    for (uint i = tid; i < n; i += tg_size) {
376        data[i] *= inv_sum;
377    }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384    device const float* gate [[buffer(0)]],
385    device const float* up   [[buffer(1)]],
386    device float* output     [[buffer(2)]],
387    constant uint& n         [[buffer(3)]],
388    uint id [[thread_position_in_grid]])
389{
390    if (id >= n) return;
391    float g = gate[id];
392    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397//   gate = gate_up[0..n], up = gate_up[n..2*n]
398//   output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400    device const float* gate_up [[buffer(0)]],
401    device float* output        [[buffer(1)]],
402    constant uint& n            [[buffer(2)]],
403    uint id [[thread_position_in_grid]])
404{
405    if (id >= n) return;
406    float g = gate_up[id];
407    float u = gate_up[n + id];
408    output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414    device const float* a  [[buffer(0)]],
415    device const float* b  [[buffer(1)]],
416    device float* output   [[buffer(2)]],
417    constant uint& n       [[buffer(3)]],
418    uint id [[thread_position_in_grid]])
419{
420    if (id >= n) return;
421    output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428    device const float* src [[buffer(0)]],
429    device float* dst       [[buffer(1)]],
430    constant uint& count    [[buffer(2)]],
431    uint id [[thread_position_in_grid]])
432{
433    if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440    device const float* src     [[buffer(0)]],
441    device float* dst           [[buffer(1)]],
442    constant uint& src_offset   [[buffer(2)]],  // in floats
443    constant uint& count        [[buffer(3)]],
444    uint id [[thread_position_in_grid]])
445{
446    if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── add_inplace ─────────────────────────────────────────────────────────
450// In-place residual connection: a[i] += b[i]
451// Avoids a separate blit copy for residual add, reducing encoder overhead.
452kernel void add_inplace(
453    device float* a        [[buffer(0)]],
454    device const float* b  [[buffer(1)]],
455    constant uint& n       [[buffer(2)]],
456    uint id [[thread_position_in_grid]])
457{
458    if (id >= n) return;
459    a[id] += b[id];
460}
461
462// ── matmul_vec_q8 ─────────────────────────────────────────────────────
463// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
464// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
465// Operates directly on quantized weights to halve memory bandwidth vs f32,
466// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
467//
468// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
469// because int8->float conversion doubles register demand.  Fully unrolled
470// inner loop with float4 vector loads from shared memory eliminates loop
471// overhead and enables better instruction scheduling.
472// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
473constant constexpr uint Q8_ROWS_PER_SG = 4;
474constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
475
476// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
477// doubles ALU work, so the same register budget applies).
478constant constexpr uint Q4_ROWS_PER_SG = 4;
479constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
480
481kernel void matmul_vec_q8(
482    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
483    device const float* vector   [[buffer(1)]],  // f32 input
484    device float* output         [[buffer(2)]],
485    constant uint& rows          [[buffer(3)]],
486    constant uint& cols          [[buffer(4)]],  // number of elements per row
487    uint tgid [[threadgroup_position_in_grid]],
488    uint tid [[thread_index_in_threadgroup]],
489    uint simd_lane [[thread_index_in_simdgroup]],
490    uint simd_id [[simdgroup_index_in_threadgroup]])
491{
492    // Load vector into shared memory
493    threadgroup float vec_tile[VEC_TILE_SIZE];
494    for (uint i = tid; i < cols; i += 256) {
495        vec_tile[i] = vector[i];
496    }
497    threadgroup_barrier(mem_flags::mem_threadgroup);
498
499    // Each simdgroup handles 4 consecutive rows (lower register pressure)
500    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
501    if (row_base >= rows) return;
502
503    // Q8_0: each block is 34 bytes for 32 elements
504    uint blocks_per_row = cols / 32;
505    uint row_bytes = blocks_per_row * 34;
506
507    // Pointers to each row's Q8_0 data
508    device const uchar* r0 = matrix + row_base * row_bytes;
509    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
510    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
511    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
512
513    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
514
515    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
516        uint bb = blk * 34;
517        uint vb = blk * 32;
518
519        // Prefetch all 4 scales
520        float sc0 = float(*(device const half*)(r0 + bb));
521        float sc1 = float(*(device const half*)(r1 + bb));
522        float sc2 = float(*(device const half*)(r2 + bb));
523        float sc3 = float(*(device const half*)(r3 + bb));
524
525        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
526        // Q8_0 block layout where the int8 data starts at offset +2 from a
527        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
528        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
529        // reduction in memory transactions. Metal's char16/packed_char16 are
530        // reserved types and packed_*int4 require >=4-byte alignment which
531        // this layout does not provide, so packed_short4 is the widest valid
532        // vectorized load.
533        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
534        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
535        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
536        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
537
538        // Load all 8 float4 vector values for this 32-element block from shared memory
539        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
540        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
541        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
542        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
543        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
544        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
545        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
546        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
547
548        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
549        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
550        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
551            short4 _s = short4(SHORT4); \
552            char2 _a = as_type<char2>(_s.x); \
553            char2 _b = as_type<char2>(_s.y); \
554            char2 _c = as_type<char2>(_s.z); \
555            char2 _d = as_type<char2>(_s.w); \
556            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
557            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
558        }
559
560        float4 f0, f1;
561        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
562
563        // Row 0: 4 short4 loads cover 32 int8 weights
564        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
565        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
566        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
567        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
568
569        // Row 1
570        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
571        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
572        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
573        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
574
575        // Row 2
576        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
577        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
578        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
579        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
580
581        // Row 3
582        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
583        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
584        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
585        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
586
587        #undef Q8_UNPACK8
588
589        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
590    }
591
592    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
593    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
594
595    if (simd_lane == 0) {
596        if (row_base     < rows) output[row_base]     = sum0;
597        if (row_base + 1 < rows) output[row_base + 1] = sum1;
598        if (row_base + 2 < rows) output[row_base + 2] = sum2;
599        if (row_base + 3 < rows) output[row_base + 3] = sum3;
600    }
601}
602
603// ── matmul_vec_q4 ─────────────────────────────────────────────────────
604// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
605// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
606// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
607// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
608//
609// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
610// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
611kernel void matmul_vec_q4(
612    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
613    device const float* vector   [[buffer(1)]],  // f32 input
614    device float* output         [[buffer(2)]],
615    constant uint& rows          [[buffer(3)]],
616    constant uint& cols          [[buffer(4)]],  // number of elements per row
617    uint tgid [[threadgroup_position_in_grid]],
618    uint tid [[thread_index_in_threadgroup]],
619    uint simd_lane [[thread_index_in_simdgroup]],
620    uint simd_id [[simdgroup_index_in_threadgroup]])
621{
622    // Load vector into shared memory
623    threadgroup float vec_tile[VEC_TILE_SIZE];
624    for (uint i = tid; i < cols; i += 256) {
625        vec_tile[i] = vector[i];
626    }
627    threadgroup_barrier(mem_flags::mem_threadgroup);
628
629    // Each simdgroup handles 4 consecutive rows
630    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
631    if (row_base >= rows) return;
632
633    // Q4_0: each block is 18 bytes for 32 elements
634    uint blocks_per_row = cols / 32;
635    uint row_bytes = blocks_per_row * 18;
636
637    // Pointers to each row's Q4_0 data
638    device const uchar* r0 = matrix + row_base * row_bytes;
639    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
640    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
641    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
642
643    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
644
645    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
646        uint bb = blk * 18;
647        uint vb = blk * 32;
648
649        // Prefetch all 4 scales
650        float sc0 = float(*(device const half*)(r0 + bb));
651        float sc1 = float(*(device const half*)(r1 + bb));
652        float sc2 = float(*(device const half*)(r2 + bb));
653        float sc3 = float(*(device const half*)(r3 + bb));
654
655        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
656        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
657        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
658        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
659        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
660
661        // Load 8 float4 vector values for 32 elements from shared memory
662        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
663        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
664        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
665        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
666        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
667        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
668        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
669        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
670        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
671
672        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
673        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
674        float bd0=0, bd1=0, bd2=0, bd3=0;
675        uchar4 b;
676
677        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
678        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
679                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
680                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
681                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
682        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
683                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
684                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
685                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
686        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
687                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
688                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
689                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
690        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
691                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
692                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
693                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
694
695        // Row 1
696        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
697                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
698                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
699                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
700        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
701                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
702                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
703                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
704        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
705                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
706                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
707                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
708        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
709                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
710                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
711                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
712
713        // Row 2
714        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
715                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
716                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
717                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
718        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
719                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
720                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
721                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
722        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
723                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
724                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
725                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
726        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
727                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
728                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
729                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
730
731        // Row 3
732        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
733                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
734                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
735                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
736        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
737                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
738                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
739                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
740        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
741                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
742                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
743                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
744        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
745                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
746                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
747                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
748
749        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
750    }
751
752    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
753    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
754
755    if (simd_lane == 0) {
756        if (row_base     < rows) output[row_base]     = sum0;
757        if (row_base + 1 < rows) output[row_base + 1] = sum1;
758        if (row_base + 2 < rows) output[row_base + 2] = sum2;
759        if (row_base + 3 < rows) output[row_base + 3] = sum3;
760    }
761}
762
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores (scaled by 1/sqrt(head_dim)) using 32-lane simd dot
// products, applies a max-subtracted softmax with simd_max/simd_sum
// reductions, then takes the score-weighted sum of V.
// Each threadgroup handles one query head. The strides (8 for simdgroups,
// 256 for threads) and the 8-entry shared_max/shared_sum arrays hard-code a
// 256-thread threadgroup (8 simdgroups of 32 lanes).
//
// Query head h reads KV head h / (num_heads / num_kv_heads) (grouped-query
// attention); presumably num_heads is a multiple of num_kv_heads — not
// checked here.
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
//
// NOTE(review): seq_len is never compared against ATTN_SCORES_SIZE. The host
// must keep seq_len <= ATTN_SCORES_SIZE, otherwise step 1 writes past the end
// of the threadgroup `scores` array.
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // tgid is uniform across the threadgroup, so the whole group returns
    // together and no barrier below is skipped by a subset of threads.
    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Step 1: Compute attention scores Q·K^T with simdgroup reduction.
    // Simdgroup simd_id handles positions simd_id, simd_id + 8, ...; its 32
    // lanes split head_dim and simd_sum combines the partial dot products.
    // Scores are kept in threadgroup memory, one float per cached position,
    // so ATTN_SCORES_SIZE is a hard upper bound on seq_len (no bounds check).
    // NOTE(review): confirm the host-side sequence cap (MAX_SEQ_LEN) is at or
    // below ATTN_SCORES_SIZE.
    threadgroup float scores[ATTN_SCORES_SIZE];  // one score per position; requires seq_len <= ATTN_SCORES_SIZE

    for (uint s = simd_id; s < seq_len; s += 8) {
        // Position s of KV head kv_head in the [seq, kv_head, dim] cache.
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax over scores (cooperative)
    // Find max: per-thread max, then per-simdgroup simd_max, then thread 0
    // folds the 8 simdgroup maxima. Threads with no positions contribute
    // -INFINITY, which never wins.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exp and sum. Each thread rewrites only the positions it owns
    // (s = tid, tid + 256, ...), so the in-place exp needs no extra barrier.
    // Thread 0 stores the reciprocal of the total so the normalization
    // below is a multiply.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    // Normalize in place; the barrier makes every probability visible to all
    // threads before step 3 reads the full scores array.
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V: output = scores · V
    // Thread tid produces output dimensions tid, tid + 256, ...; when
    // head_dim < 256 the remaining threads do no work here.
    // The sequence loop is unrolled by 4 (four scalar score reads and four V
    // loads per iteration) for ILP, with a scalar tail for seq_len % 4.
    uint seq_len4 = seq_len & ~3u;  // largest multiple of 4 <= seq_len
    uint v_stride = num_kv_heads * head_dim;
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * v_cache[s * v_stride + v_base]
                 + sc1 * v_cache[(s+1) * v_stride + v_base]
                 + sc2 * v_cache[(s+2) * v_stride + v_base]
                 + sc3 * v_cache[(s+3) * v_stride + v_base];
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
879
880// ── Batched prefill kernels ────────────────────────────────────────────
881// These kernels process M input vectors against the same weight matrix
882// in a single dispatch, converting mat-vec into mat-mat for better GPU
883// utilization during prompt prefill.
884
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization, one threadgroup per token row of length n:
//   output[t, i] = input[t, i] * rsqrt(sum_j(input[t, j]^2) / n + eps) * weight[i]
// Grid: M threadgroups (one per token). The 256-element stride and the
// 8-slot partial array assume 256 threads (8 simdgroups of 32) per group.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    threadgroup float partial[8];

    // tgid is uniform across the threadgroup, so every thread leaves together.
    if (tgid >= num_tokens) return;

    device const float* src = input + tgid * n;
    device float* dst = output + tgid * n;

    // Each thread accumulates squares over a 256-strided slice of the row.
    float acc = 0.0f;
    for (uint k = tid; k < n; k += 256) {
        float x = src[k];
        acc += x * x;
    }

    // Reduce inside each simdgroup, then park one partial per simdgroup.
    acc = simd_sum(acc);
    uint lane = tid & 31u;
    uint sg = tid >> 5;
    if (lane == 0) {
        partial[sg] = acc;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the 8 partials and publishes the inverse RMS in slot 0.
    if (tid == 0) {
        float total = 0.0f;
        for (uint w = 0; w < 8; w++) {
            total += partial[w];
        }
        partial[0] = fast::rsqrt(total / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = partial[0];
    for (uint k = tid; k < n; k += 256) {
        dst[k] = src[k] * scale * weight[k];
    }
}
934
// ── rope_batch ─────────────────────────────────────────────────────────
// Batched rotary position embedding, applied in place.
// data is [M, num_heads * head_dim]; positions[t] is the position of token t.
// One thread per (token, head, pair): pair p rotates the adjacent elements
// (2p, 2p+1) of its head by angle positions[t] / theta^(2p / head_dim).
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    uint pairs_per_head = head_dim / 2;
    uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    uint token = id / pairs_per_token;
    uint pair_in_token = id % pairs_per_token;
    uint head = pair_in_token / pairs_per_head;
    uint pair = pair_in_token % pairs_per_head;

    // p[0] and p[1] are the two elements rotated by this thread.
    device float* p = data + token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    float exponent = 2.0f * float(pair) / float(head_dim);
    float freq = 1.0f / pow(theta, exponent);
    float angle = float(positions[token]) * freq;
    float cos_a = cos(angle);
    float sin_a = sin(angle);

    float a = p[0];
    float b = p[1];
    p[0] = a * cos_a - b * sin_a;
    p[1] = a * sin_a + b * cos_a;
}
969
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Batched fused SiLU-multiply. gate_up is [M, 2*n]: per token, the first n
// values are the gate and the next n are multiplied in.
//   output[t*n + i] = silu(gate_up[t*2*n + i]) * gate_up[t*2*n + n + i]
// with silu(x) = x / (1 + exp(-x)). One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * n) {
        uint t = id / n;
        // gate[0] is this element's gate value; gate[n] is its partner.
        device const float* gate = gate_up + t * 2 * n + id % n;
        float g = gate[0];
        float u = gate[n];
        // id == t * n + (id % n), i.e. the same slot as output[t*n + i].
        output[id] = (g / (1.0f + fast::exp(-g))) * u;
    }
}
989
// ── add_inplace_batch ──────────────────────────────────────────────────
// Batched in-place residual add over all total = M * n elements: a += b.
// One thread per element; threads beyond `total` do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1001
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous [M, dim] buffer.
// Thread id copies element (id % dim) of the row selected by tokens[id / dim].
// Token ids are not range-checked against the vocabulary here.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    uint row = tokens[id / dim];
    uint col = id % dim;
    output[id] = embed[row * dim + col];
}
1019
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair; threadgroups are laid
// out token-major (tgid = token * row_tgs + row group).
//
// Each simdgroup produces 8 consecutive output rows (row_base .. row_base+7),
// with its 32 lanes splitting the columns; ROWS_PER_SIMDGROUP is therefore
// expected to be 8 to match the hand-unrolled code — NOTE(review): confirm.
// Assumptions not checked here:
//   - 256 threads per threadgroup (stride of the vec_tile load loop);
//   - cols <= VEC_TILE_SIZE (the token vector is staged in threadgroup memory);
//   - the float4 matrix loads appear to assume every row start is 16-byte
//     aligned, i.e. cols is a multiple of 4 — NOTE(review): confirm.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so it cannot strand the barrier below.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Per-simdgroup early exit is safe: no barriers follow.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // base1..base7 may point past the last row for a trailing partial group;
    // every use of them below is guarded by `row_base + k < rows`.
    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Vectorized part covers the largest multiple of 128 columns
    // (32 lanes x 4 floats per iteration); the rest goes to the scalar tail.
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Collapse each float4 accumulator to a per-lane scalar partial.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail: the remaining cols % 128 columns, one per lane.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Combine the 32 lane partials; lane 0 then holds each row's dot product.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1121
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups, laid out token-major
// (tgid = token * row_tgs + row group). Each threadgroup stages one token's
// input vector in threadgroup memory; each simdgroup then computes 4
// consecutive output rows (r0..r3), its 32 lanes striding over the Q8_0
// blocks. Q8_ROWS_PER_SG is expected to be 4 to match that unrolling.
//
// Q8_0 row layout: cols/32 blocks of 34 bytes — a half scale followed by 32
// signed 8-bit quants. cols is assumed to be a multiple of 32 (any tail is
// ignored) and to fit in VEC_TILE_SIZE (not checked here). The vec_tile load
// loop assumes 256 threads per threadgroup.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so it cannot strand the barrier below.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Per-simdgroup early exit is safe: no barriers follow.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // The block loop reads all four rows unconditionally. When rows is not a
    // multiple of 4, the trailing row group would read past the end of the
    // weight buffer, so clamp the extra rows to the last valid row instead
    // (branch-free). Their duplicated results are discarded by the guarded
    // stores at the end; rows >= 1 here because row_base < rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;
        uint vb = blk * 32;

        // Per-block half-precision scales (first 2 bytes of each block).
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 bytes) into 8 signed quants as two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply each block's scale once to its integer-weight dot product.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the 32 lane partials; lane 0 then holds each row's result.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1237
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers Q8_ROWS_PER_TG rows and TOKENS_PER_TG_Q8
// consecutive tokens: each Q8_0 weight block is fetched from device memory
// once and reused for every token in the tile (about 1/TOKENS_PER_TG_Q8 of
// the weight bandwidth of the per-token matmul_vec_q8_batch dispatch).
//
// Grid: (ceil(rows/Q8_ROWS_PER_TG), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each simdgroup owns Q8_ROWS_PER_SG consecutive rows (4, matching the
// unrolled r0..r3) and its 32 lanes stride over those rows' Q8_0 blocks.
// Token vectors are read directly from device memory inside the block loop
// rather than staged in threadgroup memory, so cols is not limited by a
// threadgroup-memory tile.
//
// NOTE(review): as written, r1..r3 are not clamped to the last valid row, so
// a trailing partial row group (rows not a multiple of 4) reads past the end
// of the weight buffer; consider clamping as in a bounds-safe mat-vec.
//
// TOKENS_PER_TG_Q8 must stay 4: the kernel hand-unrolls exactly four token
// accumulator sets (s0x..s3x) and per-token sections.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1251
1252kernel void matmul_q8_gemm_batch(
1253    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1254    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1255    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1256    constant uint& num_tokens    [[buffer(3)]],  // M
1257    constant uint& rows          [[buffer(4)]],
1258    constant uint& cols          [[buffer(5)]],
1259    uint2 tgid [[threadgroup_position_in_grid]],
1260    uint tid [[thread_index_in_threadgroup]],
1261    uint simd_lane [[thread_index_in_simdgroup]],
1262    uint simd_id [[simdgroup_index_in_threadgroup]])
1263{
1264    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1265    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1266    if (row_base >= rows || tok_base >= num_tokens) return;
1267
1268    // How many tokens in this tile are valid?
1269    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1270
1271    uint blocks_per_row = cols / 32;
1272    uint row_bytes = blocks_per_row * 34;
1273
1274    device const uchar* r0 = matrix + row_base * row_bytes;
1275    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1276    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1277    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1278
1279    // Accumulators: 4 tokens × 4 rows per simdgroup.
1280    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1281    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1282    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1283    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1284
1285    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1286        uint bb = blk * 34;
1287        uint vb = blk * 32;
1288
1289        // ── Load weight data ONCE per block (reused across all tokens) ──
1290        float sc0 = float(*(device const half*)(r0 + bb));
1291        float sc1 = float(*(device const half*)(r1 + bb));
1292        float sc2 = float(*(device const half*)(r2 + bb));
1293        float sc3 = float(*(device const half*)(r3 + bb));
1294
1295        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1296        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1297        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1298        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1299
1300        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1301            short4 _s = short4(SHORT4); \
1302            char2 _a = as_type<char2>(_s.x); \
1303            char2 _b = as_type<char2>(_s.y); \
1304            char2 _c = as_type<char2>(_s.z); \
1305            char2 _d = as_type<char2>(_s.w); \
1306            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1307            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1308        }
1309
1310        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1311        // registers for the duration of the block and are dotted against
1312        // every token's vector tile.
1313        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1314        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1315        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1316        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1317
1318        Q8_UNPACK8(d0[0], w0_0, w0_1);
1319        Q8_UNPACK8(d0[1], w0_2, w0_3);
1320        Q8_UNPACK8(d0[2], w0_4, w0_5);
1321        Q8_UNPACK8(d0[3], w0_6, w0_7);
1322
1323        Q8_UNPACK8(d1[0], w1_0, w1_1);
1324        Q8_UNPACK8(d1[1], w1_2, w1_3);
1325        Q8_UNPACK8(d1[2], w1_4, w1_5);
1326        Q8_UNPACK8(d1[3], w1_6, w1_7);
1327
1328        Q8_UNPACK8(d2[0], w2_0, w2_1);
1329        Q8_UNPACK8(d2[1], w2_2, w2_3);
1330        Q8_UNPACK8(d2[2], w2_4, w2_5);
1331        Q8_UNPACK8(d2[3], w2_6, w2_7);
1332
1333        Q8_UNPACK8(d3[0], w3_0, w3_1);
1334        Q8_UNPACK8(d3[1], w3_2, w3_3);
1335        Q8_UNPACK8(d3[2], w3_4, w3_5);
1336        Q8_UNPACK8(d3[3], w3_6, w3_7);
1337
1338        #undef Q8_UNPACK8
1339
1340        // ── For each token, read vector and accumulate against shared weights ──
1341        // Token 0 (always valid: tok_count >= 1).
1342        {
1343            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1344            float4 v0 = *(device const float4*)(a0);
1345            float4 v1 = *(device const float4*)(a0 + 4);
1346            float4 v2 = *(device const float4*)(a0 + 8);
1347            float4 v3 = *(device const float4*)(a0 + 12);
1348            float4 v4 = *(device const float4*)(a0 + 16);
1349            float4 v5 = *(device const float4*)(a0 + 20);
1350            float4 v6 = *(device const float4*)(a0 + 24);
1351            float4 v7 = *(device const float4*)(a0 + 28);
1352            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1353                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1354            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1355                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1356            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1357                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1358            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1359                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1360            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1361        }
1362        // Token 1
1363        if (tok_count > 1) {
1364            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1365            float4 v0 = *(device const float4*)(a1);
1366            float4 v1 = *(device const float4*)(a1 + 4);
1367            float4 v2 = *(device const float4*)(a1 + 8);
1368            float4 v3 = *(device const float4*)(a1 + 12);
1369            float4 v4 = *(device const float4*)(a1 + 16);
1370            float4 v5 = *(device const float4*)(a1 + 20);
1371            float4 v6 = *(device const float4*)(a1 + 24);
1372            float4 v7 = *(device const float4*)(a1 + 28);
1373            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1374                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1375            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1376                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1377            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1378                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1379            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1380                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1381            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1382        }
1383        // Token 2
1384        if (tok_count > 2) {
1385            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1386            float4 v0 = *(device const float4*)(a2);
1387            float4 v1 = *(device const float4*)(a2 + 4);
1388            float4 v2 = *(device const float4*)(a2 + 8);
1389            float4 v3 = *(device const float4*)(a2 + 12);
1390            float4 v4 = *(device const float4*)(a2 + 16);
1391            float4 v5 = *(device const float4*)(a2 + 20);
1392            float4 v6 = *(device const float4*)(a2 + 24);
1393            float4 v7 = *(device const float4*)(a2 + 28);
1394            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1395                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1396            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1397                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1398            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1399                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1400            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1401                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1402            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1403        }
1404        // Token 3
1405        if (tok_count > 3) {
1406            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1407            float4 v0 = *(device const float4*)(a3);
1408            float4 v1 = *(device const float4*)(a3 + 4);
1409            float4 v2 = *(device const float4*)(a3 + 8);
1410            float4 v3 = *(device const float4*)(a3 + 12);
1411            float4 v4 = *(device const float4*)(a3 + 16);
1412            float4 v5 = *(device const float4*)(a3 + 20);
1413            float4 v6 = *(device const float4*)(a3 + 24);
1414            float4 v7 = *(device const float4*)(a3 + 28);
1415            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1416                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1417            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1418                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1419            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1420                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1421            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1422                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1423            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1424        }
1425    }
1426
1427    // simdgroup reduction
1428    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1429    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1430    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1431    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1432
1433    if (simd_lane == 0) {
1434        device float* o0 = outputs + (tok_base + 0) * rows;
1435        if (row_base     < rows) o0[row_base]     = s00;
1436        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1437        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1438        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1439
1440        if (tok_count > 1) {
1441            device float* o1 = outputs + (tok_base + 1) * rows;
1442            if (row_base     < rows) o1[row_base]     = s10;
1443            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1444            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1445            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1446        }
1447        if (tok_count > 2) {
1448            device float* o2 = outputs + (tok_base + 2) * rows;
1449            if (row_base     < rows) o2[row_base]     = s20;
1450            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1451            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1452            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1453        }
1454        if (tok_count > 3) {
1455            device float* o3 = outputs + (tok_base + 3) * rows;
1456            if (row_base     < rows) o3[row_base]     = s30;
1457            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1458            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1459            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1460        }
1461    }
1462}
1463
1464// ── matmul_q8_mma ──────────────────────────────────────────────────────
1465// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
1466// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
1467// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
1468// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
1469//
1470// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
1471// 4 simdgroups per TG, each computing a single 8×8 output sub-tile via one
1472// simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes are cooperatively
1473// dequantized into threadgroup memory once per block and reused by all
1474// simdgroups in the tile.
1475//
1476// Assumptions (verified in the dispatch helper, falls back otherwise):
1477//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
1478//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
//   * num_tokens may be any value; a partial token tile at the boundary is
//     zero-padded on load and written back via a per-simdgroup scratch copy.
// Output tile handled by one matmul_q8_mma threadgroup: 16 weight rows
// (N axis) by 16 tokens (M axis).
constant constexpr uint MMA_ROW_TILE = 16u;
constant constexpr uint MMA_TOK_TILE = 16u;
1483
kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],  // x: row tile, y: token tile
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Origin of this threadgroup's 16×16 output tile.
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Both bases depend only on tgid, so the whole threadgroup exits together
    // and the threadgroup_barrier calls below are never split.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    // Both are row-major with stride 32 (one Q8_0 block along K).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8;  // col within output tile (row dim)

    // Per-simdgroup 8×8 accumulator, kept in f32 across all K blocks.
    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    // Q8_0 row layout: blocks_per_row blocks of 34 bytes, each a 2-byte half
    // scale followed by 32 int8 quants.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        // NOTE(review): relies on exactly 128 threads per threadgroup; any other
        // size leaves part of the tile unwritten or indexes past it.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Tokens past num_tokens are zero-filled so they add nothing to C;
        // their rows are never written back.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Both tiles must be fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B,
                w_tile + sg_row_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // All simdgroups must finish reading the tiles before the next
        // iteration overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    // out_tok depends only on simd_id, so each branch below is taken by a
    // whole simdgroup and the simdgroup_store calls run on a full simdgroup.
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage the 8×8 result in this simdgroup's
        // 64-float slice of scratch and scalar-copy only the valid tokens.
        threadgroup float scratch[4 * 64];
        simdgroup_store(C, scratch + simd_id * 64, 8);
        // Only this simdgroup touches its scratch slice, so a simdgroup-scope
        // barrier suffices before lane 0 reads it back.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // Assumes 32-wide simdgroups.
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + simd_id * 64 + m * 8;
                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
            }
        }
    }
}
1592
1593// ── matmul_q8_mma32 ────────────────────────────────────────────────────
1594// Larger-tile variant of matmul_q8_mma for long-context prefill.
1595//
1596// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
1597// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
1598// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
1599//
1600//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
1601//     output sub-tiles (tok, row):
1602//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
1603//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
1604//
1605// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
1606// iteration — better FLOP/load ratio than the 16×16 single-accumulator
// kernel — and needs a quarter as many threadgroups as the 16×16 tile, since
// each 32×32 tile covers four 16×16 tiles.
1608//
1609// Assumptions (verified in dispatch helper, fallback otherwise):
1610//   * cols % 32 == 0
1611//   * rows % 32 == 0
// Output tile handled by one threadgroup of matmul_q8_mma32 and its FP16
// variants: 32 weight rows (N axis) by 32 tokens (M axis).
constant constexpr uint MMA32_ROW_TILE = 32u;
constant constexpr uint MMA32_TOK_TILE = 32u;
1614
kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],  // x: row tile, y: token tile
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Origin of this threadgroup's 32×32 output tile.
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so no barrier is split.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    // Row-major, stride 32 (one Q8_0 block along K).
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    // simd_id = 2*sg_tok_idx + sg_row_half; simd_id is expected in 0..7.
    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    // Two f32 accumulators per simdgroup, stacked along the row axis.
    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    // Q8_0 row layout: 34-byte blocks (2-byte half scale + 32 int8 quants).
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        // NOTE(review): relies on exactly 256 threads per threadgroup.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load.  Tokens past num_tokens are
        // zero-filled; their outputs are never written back.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Tiles must be fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        // B fragments are loaded transposed so B[k, r] = w_tile[r * 32 + k].
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // All simdgroups must finish reading before the tiles are reloaded.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    // out_tok depends only on simd_id, so each branch is simdgroup-uniform.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Stage both accumulators in this simdgroup's 128-float scratch slice,
        // then copy only the valid token rows.
        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        // Each simdgroup reads only its own slice: simdgroup scope suffices.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // Assumes 32-wide simdgroups.
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1733
1734// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
1735// FP16 threadgroup-tile variant of matmul_q8_mma32.
1736//
1737// Stores dequantized weights and token inputs as `half` in threadgroup
1738// memory — halving the shared-memory footprint (4 KB total vs 8 KB) and
1739// doubling concurrent-threadgroup occupancy per GPU core on Apple Silicon.
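// Q8_0 weights are int8 × f16 scale (the per-block scale is read as `half`),
// so an f16 intermediate covers the quantized dynamic range, although each
// dequantized int8×scale product is rounded to f16's 11-bit significand.
// Token activations are narrowed to f16 on load (assumes their magnitude stays
// well below the f16 limit of 65504 — TODO confirm for every matmul input); the
// subsequent `simdgroup_multiply_accumulate` keeps the accumulator in `float`.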
1744//
1745// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
1746// accumulators each.  Primary win vs mma32 is occupancy at moderate
1747// prefill lengths where the GPU is wave-starved.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],  // x: row tile, y: token tile
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Origin of this threadgroup's 32×32 output tile.
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so no barrier is split.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Same 8-simdgroup layout as matmul_q8_mma32: simd_id = 2*sg_tok_idx + sg_row_half.
    uint sg_tok_idx  = simd_id / 2;
    uint sg_row_half = simd_id % 2;
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    // Accumulators stay f32 even though the A/B fragments are half.
    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    // Q8_0 row layout: 34-byte blocks (2-byte half scale + 32 int8 quants).
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16.
        // 1024 elements / 4 per thread: relies on 256 threads per threadgroup.
        // The product is formed in f32 and rounded once to half.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing).
        // Tokens past num_tokens are zero-filled and never written back.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Half A/B fragments, f32 accumulation; A is reused for both B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // All simdgroups must finish reading before the tiles are reloaded.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store: identical to matmul_q8_mma32.  The row dimension is assumed
    // tile-aligned; only the last token tile may be partial.  out_tok depends
    // only on simd_id, so each branch is simdgroup-uniform.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // f32 scratch: 8 simdgroups × 2 accumulators × 64 floats.
        threadgroup float scratch[8 * 2 * 64];
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        // Each simdgroup reads only its own slice: simdgroup scope suffices.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // Assumes 32-wide simdgroups.
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1862
1863// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
1864// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
1865//
1866// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
1867// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
1868//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
1869//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
1870// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
1871//
1872// Per K_sub iteration: load two A fragments and two B fragments, then run
1873// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's.  That is 1.5× the MMAs-per-fragment-load of the 2-accumulator
// kernel (4 MMAs per 4 loads vs 2 MMAs per 3 loads), and it halves the thread
// count per threadgroup (128
1876// threads), which often improves occupancy on Apple GPUs where the
1877// concurrent-thread budget is the tighter limit than shared-memory size.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],  // x: row tile, y: token tile
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Origin of this threadgroup's 32×32 output tile.
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so no barrier is split.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // Four f32 accumulators per simdgroup: C_<tok half><row half>.
    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    // Q8_0 row layout: 34-byte blocks (2-byte half scale + 32 int8 quants).
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        // NOTE(review): relies on exactly 128 threads per threadgroup.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16).  Tokens past num_tokens
        // are zero-filled and never written back.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        // B fragments are loaded transposed so B[k, r] = w_tile[r * 32 + k].
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // All simdgroups must finish reading before the tiles are reloaded.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  Full-tile fast path assumes full 16×16 valid.
    // All four out_* values depend only on simd_id, so the branch below is
    // simdgroup-uniform.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch.
        // Also entered by a quadrant that lies entirely past num_tokens; the
        // per-token checks below then skip every write.
        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base +   0, 8);
        simdgroup_store(C_01, scratch + sg_base +  64, 8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        // Each simdgroup reads only its own slice: simdgroup scope suffices.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // Assumes 32-wide simdgroups.
        uint lane = tid % 32;
        if (lane == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch + sg_base +   0 + m * 8;
                    threadgroup const float* src1 = scratch + sg_base +  64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2017
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Computes outputs[tok, row] = sum_k inputs[tok, k] * W[row, k], where W is a
// Q8_0 matrix stored as cols/32 blocks of 34 bytes per row (a half scale
// followed by 32 int8 quants).  Threadgroup (tgid.x, tgid.y) owns an
// MMA32_TOK_TILE x MMA32_ROW_TILE output tile; each simdgroup owns a 16x16
// sub-tile of it, held as four 8x8 accumulators.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights carry only 8 bits of quantized precision and
// post-RMSNorm activations are ≈ O(1).  Summing 2048-wide K for 1B or
// 8192-wide K for the FFN may exceed half's ~3.3-digit precision on extreme
// values, but the inputs are already quantized so the per-product error floor
// is higher than the half-precision rounding error.
// NOTE(review): half accumulators also saturate at ±65504.  If the 8192-wide
// case is the FFN down projection, its input is presumably the gated FFN
// activation rather than an RMSNorm output, so overflow to inf (not just
// rounding) should be part of the verification below.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// NOTE(review): assumptions the kernel does not check — confirm at dispatch:
//   * rows is a multiple of MMA32_ROW_TILE: weight rows and output columns in
//     row_base + [0, MMA32_ROW_TILE) are accessed without a `< rows` guard;
//   * cols is a multiple of 32: blocks_per_row = cols / 32 drops any tail;
//   * the threadgroup has 4 simdgroups and tid * 8 + [0, 8) covers each
//     32-wide staging tile (presumably 128 threads with 32x32 tiles); the
//     epilogue scratch buffer is sized for exactly 4 simdgroups.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and the
    // barriers below are never reached by only part of it.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Staging tiles for one 32-wide K block, row-major with 32 halves per row:
    //   w_tile[r * 32 + k] = dequantized W[row_base + r, blk * 32 + k]
    //   t_tile[m * 32 + k] = inputs[tok_base + m, blk * 32 + k]
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // This simdgroup's 16x16 sub-tile: simd_id % 2 selects the row half and
    // simd_id / 2 the token half of the threadgroup tile.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // C_<t><r>: 8x8 half accumulators; t = token half (0 top, 1 bottom),
    // r = row half (0 low, 1 high).
    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    // Q8_0: each 32-element block is 2 bytes of half scale + 32 int8 quants.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        {
            // Dequantize this K block of the tile's weight rows into w_tile;
            // each thread fills 8 consecutive tile elements.
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        {
            // Stage the matching K slice of the activations (narrowed to half);
            // token slots past num_tokens are zero-filled.
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four 8-deep MMA steps cover the 32-wide K block.  A tiles read as
        // [tok][k]; B tiles are loaded transposed so they read as [k][row],
        // making C = A * B an [tok][row] product.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Every simdgroup must finish reading the tiles before the next
        // iteration overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;

    // 256 halves per simdgroup: C_00, C_01, C_10, C_11 at offsets 0/64/128/192,
    // each stored row-major [tok][row] with stride 8.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);
    // NOTE(review): only lane 0 of each simdgroup performs the 256 scalar
    // stores below; the other 31 lanes idle during this epilogue.
    uint lane = tid % 32;
    if (lane == 0) {
        for (uint m = 0; m < 8; m++) {
            // Token rows past num_tokens were zero-padded inputs; skip them.
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base +   0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base +  64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst0[j] = float(src0[j]);
                    dst1[j] = float(src1[j]);
                }
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst2[j] = float(src2[j]);
                    dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2155
// ── add_bias_batch ─────────────────────────────────────────────────────
// Adds the same per-column bias vector to every token row of a row-major
// [num_tokens, rows] matrix, in place (used for the Qwen2 QKV bias after
// the fused qkv matmul):
//     out[token, i] += bias[i]    for i in 0..rows, token in 0..num_tokens
// One thread per output element; threads beyond the matrix do nothing.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows]
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * rows) {
        // id % rows is the column index within the flattened token row.
        out[id] = out[id] + bias[id % rows];
    }
}
2172
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors:
//     outputs[token, row] = sum_k W[row, k] * inputs[token, k]
// Each Q4_0 row holds cols/32 blocks of 18 bytes: a half scale followed by
// 16 bytes of nibbles.  The low nibble of byte n is element n, the high
// nibble is element n + 16, and both are stored biased by +8.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.  The activation loader
// strides by 256, so threadgroups are assumed to be 256 threads (8
// simdgroups).  Each simdgroup reduces Q4_ROWS_PER_SG consecutive rows across
// its 32 lanes; the body below is written for Q4_ROWS_PER_SG == 4.

// Dot product of four packed Q4_0 bytes with their activations.  The low
// nibble of each byte pairs with the matching lane of `lo`, and the high
// nibble pairs with the matching lane of `hi` (16 elements later).
inline float q4_batch_nibble_dot(packed_uchar4 b, float4 lo, float4 hi)
{
    return float(int(b.x & 0xF) - 8) * lo.x + float(int(b.x >> 4) - 8) * hi.x
         + float(int(b.y & 0xF) - 8) * lo.y + float(int(b.y >> 4) - 8) * hi.y
         + float(int(b.z & 0xF) - 8) * lo.z + float(int(b.z >> 4) - 8) * hi.z
         + float(int(b.w & 0xF) - 8) * lo.w + float(int(b.w >> 4) - 8) * hi.w;
}

kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Threadgroups are laid out token-major: row_tgs consecutive threadgroups
    // share one token.
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's activations in threadgroup memory.
    // NOTE(review): vec_tile is not bounds-checked; cols must not exceed its
    // capacity (VEC_TILE_SIZE).
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Uniform per simdgroup, so simd_sum below always sees all 32 lanes.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // When rows is not a multiple of 4, the last simdgroup's trailing rows do
    // not exist.  Clamp them to the last valid row so every load stays inside
    // the matrix; their sums are discarded by the guarded stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Lanes stride over the Q4_0 blocks; each lane handles whole blocks.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;
        uint vb = blk * 32;

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // The block's 32 activations, shared by all four rows: v0..v3 are
        // elements 0..15 (low nibbles), v4..v7 are elements 16..31 (high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Packed byte group g covers elements 4g..4g+3 and 16+4g..16+4g+3.
        float bd0 = q4_batch_nibble_dot(p0[0], v0, v4) + q4_batch_nibble_dot(p0[1], v1, v5)
                  + q4_batch_nibble_dot(p0[2], v2, v6) + q4_batch_nibble_dot(p0[3], v3, v7);
        float bd1 = q4_batch_nibble_dot(p1[0], v0, v4) + q4_batch_nibble_dot(p1[1], v1, v5)
                  + q4_batch_nibble_dot(p1[2], v2, v6) + q4_batch_nibble_dot(p1[3], v3, v7);
        float bd2 = q4_batch_nibble_dot(p2[0], v0, v4) + q4_batch_nibble_dot(p2[1], v1, v5)
                  + q4_batch_nibble_dot(p2[2], v2, v6) + q4_batch_nibble_dot(p2[3], v3, v7);
        float bd3 = q4_batch_nibble_dot(p3[0], v0, v4) + q4_batch_nibble_dot(p3[1], v1, v5)
                  + q4_batch_nibble_dot(p3[2], v2, v6) + q4_batch_nibble_dot(p3[3], v3, v7);

        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2273
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatters K or V vectors out of the strided batch QKV buffer into the cache.
// Batch row t holds the kv_dim floats to copy starting at src_offset, with
// consecutive rows src_stride floats apart; they land in cache row
// base_pos + t of a contiguous [max_seq, kv_dim] layout.
// One thread per copied float (M * kv_dim threads do work).
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer
    device float* dst        [[buffer(1)]],  // KV cache
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    if (id >= M * kv_dim) return;

    uint row = id / kv_dim;
    uint col = id - row * kv_dim;  // == id % kv_dim
    dst[(base_pos + row) * kv_dim + col] = src[row * src_stride + src_offset + col];
}
2296
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
// Grouped-query attention: query head h reads KV head
// h / (num_heads / num_kv_heads).
// Threadgroup shape: the loops stride by 256 threads and reduce through
// 8-entry per-simdgroup arrays, so threadgroups are assumed to be 256
// threads (8 simdgroups of 32).
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores — sized to the effective max_seq_len
    // (4096 for all supported models) so long-context attention doesn't overflow.
    // NOTE(review): scores[] is not bounds-checked here; seq_len must not exceed
    // its capacity (ATTN_SCORES_SIZE) or Step 1 writes past the array.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction
    // Each simdgroup takes every 8th position; its 32 lanes split head_dim.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a: max over all scores (per-thread, then per-simdgroup, then tid 0).
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate in place (each thread touches the same positions it
    // scanned in 2a) and reduce the sum the same way.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // shared_sum[0] now carries the reciprocal, guarded against a zero sum.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // Each thread owns 4 consecutive output dims and walks every position, so
    // only head_dim / 4 threads do work here (16 of 256 for head_dim = 64);
    // the win is the float4 loads, not occupancy.
    // NOTE(review): the float4 casts presumably need 16-byte alignment, i.e.
    // head_dim % 4 == 0 (true for 64/128); a non-multiple-of-4 tail is handled
    // by the scalar loop below but the float4 bases may then be misaligned.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2422
// ── attention_flash_batch ─────────────────────────────────────────────
// Streaming attention with online softmax.  Same grid as attention_batch
// (M × num_heads threadgroups, one per (token, head) pair) but the scores
// matrix is never materialized.  K/V positions are processed in a tile of
// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
// the standard flash-attention recurrence:
//
//     m_new   = max(m_old, tile_max)
//     alpha   = exp(m_old - m_new)
//     l_new   = alpha * l_old + sum(exp(S - m_new))
//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
//     O_final = O / l
//
// This avoids attention_batch's fixed-capacity scores[] buffer (which is not
// bounds-checked and overruns for seq_len beyond its capacity) and keeps
// threadgroup memory use to O(head_dim + FLASH_K_TILE) instead of O(seq_len).
//
// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
// NOTE(review): head_dim ≤ FLASH_MAX_HEAD_DIM is not checked in the kernel;
// larger values overrun q_sh/o_sh.  The loops stride by 256 threads and the
// reductions use 8 simdgroup slots, so threadgroups are assumed to be 256
// threads (8 simdgroups of 32).
constant constexpr uint FLASH_K_TILE = 32;
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;

kernel void attention_flash_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],
    constant uint& base_pos          [[buffer(5)]],
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    // Grouped-query attention: num_heads / num_kv_heads query heads share a KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: attend to [0, base_pos + token_idx]
    uint q_off = token_idx * q_stride + head * head_dim;
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup state:
    //   q_sh:      Q vector for this (token, head), loaded once
    //   o_sh:      running output vector, updated each K tile
    //   scores_sh: scores for the current K tile only
    //   stats:     [running max, running sum]  (see flash-attention recurrence)
    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float scores_sh[FLASH_K_TILE];
    threadgroup float stats[2];
    threadgroup float sg_scratch[8];  // simdgroup-level reduction buffer

    // --- Load Q (one row) and zero the running O ---
    for (uint d = tid; d < head_dim; d += 256) {
        q_sh[d] = q_batch[q_off + d];
        o_sh[d] = 0.0f;
    }
    if (tid == 0) {
        stats[0] = -INFINITY;
        stats[1] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = fast::rsqrt(float(head_dim));
    uint v_stride = num_kv_heads * head_dim;
    uint v_base = kv_head * head_dim;

    // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
    // seq_len >= 1, so this loop always runs at least once before the final
    // normalization reads stats[1].
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);

        // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
        // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
        for (uint ti = simd_id; ti < tile_n; ti += 8) {
            uint k_off = (kv_base + ti) * v_stride + v_base;  // same layout as V stride
            float dot = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q_sh[d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores_sh[ti] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Tile max via cooperative reduction.
        float local_max = -INFINITY;
        for (uint s = tid; s < tile_n; s += 256) {
            local_max = max(local_max, scores_sh[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // [3] Merge with running max, compute alpha, rescale running l.
        // Only tid 0 assigns m_new/alpha here; every thread reloads them from
        // sg_scratch after the barrier below.
        float m_new;
        float alpha;
        if (tid == 0) {
            float tile_max = sg_scratch[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
            float m_old = stats[0];
            m_new = max(m_old, tile_max);
            // First iteration: m_old = -inf → alpha = 0 (reset O).
            alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            stats[0] = m_new;
            stats[1] *= alpha;
            // Broadcast via sg_scratch.
            sg_scratch[0] = alpha;
            sg_scratch[1] = m_new;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        alpha = sg_scratch[0];
        m_new = sg_scratch[1];

        // [4] Rescale running output by alpha, then compute exp(scores - m_new).
        // The barrier after this step also guarantees every thread has read
        // alpha/m_new before [5] reuses sg_scratch.
        for (uint d = tid; d < head_dim; d += 256) {
            o_sh[d] *= alpha;
        }
        for (uint s = tid; s < tile_n; s += 256) {
            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Tile sum → update running l.
        float local_sum = 0.0f;
        for (uint s = tid; s < tile_n; s += 256) {
            local_sum += scores_sh[s];
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
            stats[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = 0.0f;
            for (uint s = 0; s < tile_n; s++) {
                acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
            }
            o_sh[d] += acc;
        }
        // Protects scores_sh and o_sh before the next tile overwrites/reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // --- Normalize and write output ---
    // Each thread reads back only the o_sh entries it wrote itself (same d
    // partition as step [6]).
    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        output_batch[out_off + d] = o_sh[d] * inv_l;
    }
}
2587
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Rotary position embedding applied in place to both Q and K of every token
// in a single dispatch (one kernel launch + memory barrier per layer instead
// of two).  Within each token row of qkv_data the num_q_heads Q heads start
// at offset 0 and the num_kv_heads K heads start right after them at
// num_q_heads * head_dim, so together they form one contiguous run of
// (num_q_heads + num_kv_heads) heads of head_dim floats.
// One thread rotates one adjacent (even, odd) element pair.
kernel void rope_qk_batch(
    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
    constant uint& M                 [[buffer(1)]],   // num tokens
    constant uint& base_pos          [[buffer(2)]],   // starting position
    constant uint& num_q_heads       [[buffer(3)]],
    constant uint& num_kv_heads      [[buffer(4)]],
    constant uint& head_dim          [[buffer(5)]],
    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
    constant float& theta            [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    uint pairs_per_head = head_dim / 2;
    uint pairs_per_token = (num_q_heads + num_kv_heads) * pairs_per_head;
    uint token = id / pairs_per_token;
    if (token >= M) return;

    // Because the K heads begin exactly where the Q heads end, one head index
    // addresses both: slots [0, num_q_heads) are Q heads, the rest K heads.
    uint pair = id % pairs_per_token;
    uint head_slot = pair / pairs_per_head;
    uint i = pair % pairs_per_head;
    uint offset = token * qkv_stride + head_slot * head_dim + i * 2;

    // Pair i rotates by (absolute position) * theta^(-2i / head_dim).
    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    float angle = float(base_pos + token) * freq;
    float cos_a = cos(angle);
    float sin_a = sin(angle);

    float x_even = qkv_data[offset];
    float x_odd = qkv_data[offset + 1];
    qkv_data[offset]     = x_even * cos_a - x_odd * sin_a;
    qkv_data[offset + 1] = x_even * sin_a + x_odd * cos_a;
}
2638
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Copies K and V for the whole batch from the strided QKV buffer into their
// cache buffers in one dispatch, saving a kernel launch + memory barrier per
// layer compared with two copy_kv_batch calls.
// Threads [0, M*kv_dim) copy K; threads [M*kv_dim, 2*M*kv_dim) copy V.
kernel void copy_kv_both_batch(
    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
    constant uint& M           [[buffer(3)]],  // num tokens in batch
    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    uint per_cache = M * kv_dim;
    if (id >= per_cache * 2) return;

    // Second half of the thread range handles V.
    bool copy_v = id >= per_cache;
    uint elem = copy_v ? id - per_cache : id;
    uint token = elem / kv_dim;
    uint d = elem % kv_dim;

    device float* cache = copy_v ? v_dst : k_dst;
    uint src_col = copy_v ? v_offset : k_offset;
    cache[(base_pos + token) * kv_dim + d] = src[token * src_stride + src_col + d];
}
2673"#
2674    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2675    .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2676}
2677
2678// ---------------------------------------------------------------------------
2679// model.rs generation
2680// ---------------------------------------------------------------------------
2681
2682fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2683    let mut code = String::with_capacity(48 * 1024);
2684    emit_model_header(&mut code, config)?;
2685    emit_metal_model_struct(&mut code, config)?;
2686    emit_layer_buffers_struct(&mut code, config)?;
2687    emit_metal_model_impl(&mut code, config)?;
2688    emit_helper_functions(&mut code)?;
2689    Ok(code)
2690}
2691
2692fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2693    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2694    writeln!(
2695        code,
2696        "//! Model: {} ({} layers, hidden={})",
2697        config.architecture, config.num_layers, config.hidden_size
2698    )?;
2699    writeln!(code, "//!")?;
2700    writeln!(
2701        code,
2702        "//! Uses native Metal compute pipelines via the metal crate."
2703    )?;
2704    writeln!(code)?;
2705    writeln!(code, "#![allow(dead_code)]")?;
2706    writeln!(code)?;
2707    writeln!(code, "use metal::*;")?;
2708    writeln!(code, "#[allow(unused_imports)]")?;
2709    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2710    writeln!(code, "use std::mem;")?;
2711    writeln!(code)?;
2712
2713    // Model constants
2714    writeln!(
2715        code,
2716        "// ── Model constants ──────────────────────────────────"
2717    )?;
2718    writeln!(
2719        code,
2720        "pub const HIDDEN_SIZE: usize = {};",
2721        config.hidden_size
2722    )?;
2723    writeln!(
2724        code,
2725        "pub const INTERMEDIATE_SIZE: usize = {};",
2726        config.intermediate_size
2727    )?;
2728    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2729    writeln!(
2730        code,
2731        "pub const NUM_HEADS: usize = {};",
2732        config.num_attention_heads
2733    )?;
2734    writeln!(
2735        code,
2736        "pub const NUM_KV_HEADS: usize = {};",
2737        config.num_kv_heads
2738    )?;
2739    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2740    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2741    let effective_seq_len = config.max_seq_len.min(4096);
2742    writeln!(
2743        code,
2744        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
2745        effective_seq_len, config.max_seq_len
2746    )?;
2747    writeln!(
2748        code,
2749        "pub const RMS_NORM_EPS: f32 = {:e};",
2750        config.rms_norm_eps
2751    )?;
2752    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
2753    writeln!(
2754        code,
2755        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
2756    )?;
2757    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
2758    writeln!(code)?;
2759
2760    Ok(())
2761}
2762
2763fn emit_metal_model_struct(
2764    code: &mut String,
2765    config: &ModelConfig,
2766) -> Result<(), MetalCodegenError> {
2767    writeln!(
2768        code,
2769        "// ── MetalModel ──────────────────────────────────────────"
2770    )?;
2771    writeln!(code)?;
2772    writeln!(
2773        code,
2774        "/// Metal-accelerated transformer model for Apple Silicon."
2775    )?;
2776    writeln!(code, "///")?;
2777    writeln!(
2778        code,
2779        "/// Uses unified memory for zero-copy weight access and native Metal"
2780    )?;
2781    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
2782    writeln!(code, "pub struct MetalModel {{")?;
2783    writeln!(code, "    device: Device,")?;
2784    writeln!(code, "    queue: CommandQueue,")?;
2785    writeln!(code)?;
2786    writeln!(code, "    // ── Compute pipelines ──")?;
2787    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
2788    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
2789    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
2790    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
2791    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
2792    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
2793    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
2794    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
2795    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
2796    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
2797    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
2798    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
2799    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
2800    writeln!(code)?;
2801    writeln!(code, "    // ── Batched prefill pipelines ──")?;
2802    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
2803    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
2804    writeln!(
2805        code,
2806        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
2807    )?;
2808    writeln!(
2809        code,
2810        "    matmul_q8_mma_pipeline: ComputePipelineState,"
2811    )?;
2812    writeln!(
2813        code,
2814        "    matmul_q8_mma32_pipeline: ComputePipelineState,"
2815    )?;
2816    writeln!(
2817        code,
2818        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
2819    )?;
2820    writeln!(
2821        code,
2822        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
2823    )?;
2824    writeln!(
2825        code,
2826        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
2827    )?;
2828    if config.qkv_bias {
2829        writeln!(
2830            code,
2831            "    add_bias_batch_pipeline: ComputePipelineState,"
2832        )?;
2833    }
2834    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
2835    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
2836    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
2837    writeln!(
2838        code,
2839        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
2840    )?;
2841    writeln!(
2842        code,
2843        "    add_inplace_batch_pipeline: ComputePipelineState,"
2844    )?;
2845    writeln!(
2846        code,
2847        "    copy_embedding_batch_pipeline: ComputePipelineState,"
2848    )?;
2849    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
2850    writeln!(
2851        code,
2852        "    attention_flash_batch_pipeline: ComputePipelineState,"
2853    )?;
2854    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
2855    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
2856    writeln!(
2857        code,
2858        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
2859    )?;
2860    writeln!(code)?;
2861    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
2862    writeln!(
2863        code,
2864        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
2865    )?;
2866    writeln!(code, "    embed_buf: Buffer,")?;
2867    writeln!(code)?;
2868    writeln!(code, "    /// Per-layer weight buffers")?;
2869    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
2870    writeln!(code)?;
2871    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
2872    writeln!(code, "    norm_buf: Buffer,")?;
2873    writeln!(code)?;
2874    writeln!(
2875        code,
2876        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
2877    )?;
2878    writeln!(code, "    lm_head_buf: Buffer,")?;
2879    writeln!(code)?;
2880    writeln!(
2881        code,
2882        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
2883    )?;
2884    writeln!(code, "    hidden_buf: Buffer,")?;
2885    writeln!(code, "    residual_buf: Buffer,")?;
2886    writeln!(code, "    normed_buf: Buffer,")?;
2887    writeln!(
2888        code,
2889        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
2890    )?;
2891    writeln!(code, "    qkv_buf: Buffer,")?;
2892    writeln!(code, "    attn_out_buf: Buffer,")?;
2893    writeln!(code, "    attn_proj_buf: Buffer,")?;
2894    writeln!(
2895        code,
2896        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
2897    )?;
2898    writeln!(code, "    gate_up_buf: Buffer,")?;
2899    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
2900    writeln!(code, "    ffn_out_buf: Buffer,")?;
2901    writeln!(code, "    add_tmp_buf: Buffer,")?;
2902    writeln!(code, "    logits_buf: Buffer,")?;
2903    writeln!(code)?;
2904    writeln!(code, "    // ── Batched prefill working buffers ──")?;
2905    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
2906    writeln!(code, "    batch_hidden_buf: Buffer,")?;
2907    writeln!(
2908        code,
2909        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
2910    )?;
2911    writeln!(code, "    batch_residual_buf: Buffer,")?;
2912    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2913    writeln!(code, "    batch_qkv_buf: Buffer,")?;
2914    writeln!(
2915        code,
2916        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2917    )?;
2918    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
2919    writeln!(
2920        code,
2921        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2922    )?;
2923    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
2924    writeln!(
2925        code,
2926        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2927    )?;
2928    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
2929    writeln!(
2930        code,
2931        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2932    )?;
2933    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
2934    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2935    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
2936    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
2937    writeln!(code, "    batch_tokens_buf: Buffer,")?;
2938    writeln!(code, "    /// Positions buffer for batch RoPE")?;
2939    writeln!(code, "    batch_positions_buf: Buffer,")?;
2940    writeln!(code)?;
2941    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
2942    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
2943    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
2944    writeln!(code)?;
2945    writeln!(code, "    // ── Inference state ──")?;
2946    writeln!(code, "    pos: usize,")?;
2947    writeln!(code)?;
2948    writeln!(
2949        code,
2950        "    /// Previous command buffer for double-buffered prefill."
2951    )?;
2952    writeln!(
2953        code,
2954        "    /// While the GPU executes token N, the CPU can encode token N+1."
2955    )?;
2956    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
2957    writeln!(code, "}}")?;
2958    writeln!(code)?;
2959
2960    Ok(())
2961}
2962
2963fn emit_layer_buffers_struct(
2964    code: &mut String,
2965    config: &ModelConfig,
2966) -> Result<(), MetalCodegenError> {
2967    writeln!(
2968        code,
2969        "/// Per-layer weight buffers for attention and FFN projections."
2970    )?;
2971    writeln!(code, "struct LayerBuffers {{")?;
2972    writeln!(code, "    attn_norm: Buffer,")?;
2973    writeln!(
2974        code,
2975        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2976    )?;
2977    writeln!(code, "    qkv_weight: Buffer,")?;
2978    if config.qkv_bias {
2979        writeln!(
2980            code,
2981            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
2982        )?;
2983        writeln!(code, "    qkv_bias: Buffer,")?;
2984    }
2985    writeln!(code, "    o_weight: Buffer,")?;
2986    writeln!(code, "    ffn_norm: Buffer,")?;
2987    writeln!(
2988        code,
2989        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2990    )?;
2991    writeln!(code, "    gate_up_weight: Buffer,")?;
2992    writeln!(code, "    down_weight: Buffer,")?;
2993    writeln!(code, "}}")?;
2994    writeln!(code)?;
2995
2996    Ok(())
2997}
2998
2999fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3000    let hidden = config.hidden_size;
3001    let intermediate = config.intermediate_size;
3002    let _num_layers = config.num_layers;
3003    let num_heads = config.num_attention_heads;
3004    let num_kv_heads = config.num_kv_heads;
3005    let head_dim = config.head_dim;
3006    let vocab = config.vocab_size;
3007    let effective_seq_len = config.max_seq_len.min(4096);
3008    let is_q8 = config.dtype == DType::Q8_0;
3009    let is_q4 = config.dtype == DType::Q4_0;
3010    let kv_dim = num_kv_heads * head_dim;
3011
3012    writeln!(code, "impl MetalModel {{")?;
3013
3014    // ── new() ──
3015    writeln!(
3016        code,
3017        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3018    )?;
3019    writeln!(code, "    ///")?;
3020    writeln!(
3021        code,
3022        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3023    )?;
3024    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3025    writeln!(
3026        code,
3027        "        let device = Device::system_default().expect(\"no Metal device found\");"
3028    )?;
3029    writeln!(code, "        let queue = device.new_command_queue();")?;
3030    writeln!(code)?;
3031
3032    // Compile shaders
3033    writeln!(
3034        code,
3035        "        // Compile Metal shaders from embedded source"
3036    )?;
3037    writeln!(
3038        code,
3039        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3040    )?;
3041    writeln!(
3042        code,
3043        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3044    )?;
3045    writeln!(
3046        code,
3047        "            .expect(\"failed to compile Metal shaders\");"
3048    )?;
3049    writeln!(code)?;
3050
3051    // Create compute pipelines
3052    writeln!(code, "        // Create compute pipelines")?;
3053    for (var, fn_name) in [
3054        ("matmul_pipeline", "matmul_vec"),
3055        ("matmul_q8_pipeline", "matmul_vec_q8"),
3056        ("matmul_q4_pipeline", "matmul_vec_q4"),
3057        ("rms_norm_pipeline", "rms_norm"),
3058        ("rope_pipeline", "rope"),
3059        ("softmax_pipeline", "softmax"),
3060        ("silu_mul_pipeline", "silu_mul"),
3061        ("silu_mul_fused_pipeline", "silu_mul_fused"),
3062        ("add_pipeline", "elementwise_add"),
3063        ("attention_pipeline", "attention"),
3064        ("add_inplace_pipeline", "add_inplace"),
3065        ("copy_pipeline", "copy_buffer"),
3066        ("copy_offset_pipeline", "copy_offset"),
3067        ("matmul_batch_pipeline", "matmul_vec_batch"),
3068        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3069        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3070        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3071        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3072        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3073        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3074        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3075        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3076        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3077        ("rope_batch_pipeline", "rope_batch"),
3078        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3079        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3080        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3081        ("attention_batch_pipeline", "attention_batch"),
3082        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3083        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3084        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3085        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3086    ] {
3087        writeln!(
3088            code,
3089            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3090        )?;
3091    }
3092    if config.qkv_bias {
3093        writeln!(
3094            code,
3095            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3096        )?;
3097    }
3098    writeln!(code)?;
3099
3100    // Weight loading
3101    writeln!(
3102        code,
3103        "        // Load weights into Metal shared-memory buffers"
3104    )?;
3105    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3106    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3107    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3108    writeln!(code)?;
3109    writeln!(
3110        code,
3111        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3112    )?;
3113    writeln!(code)?;
3114    writeln!(
3115        code,
3116        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3117    )?;
3118    writeln!(
3119        code,
3120        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3121    )?;
3122    writeln!(code, "            let byte_len = n * f32_size;")?;
3123    writeln!(code, "            let cur = cursor.get();")?;
3124    writeln!(
3125        code,
3126        "            let data = &weights[cur..cur + byte_len];"
3127    )?;
3128    writeln!(code, "            cursor.set(cur + byte_len);")?;
3129    writeln!(code, "            device.new_buffer_with_data(")?;
3130    writeln!(code, "                data.as_ptr() as *const _,")?;
3131    writeln!(code, "                byte_len as u64,")?;
3132    writeln!(
3133        code,
3134        "                MTLResourceOptions::StorageModeShared,"
3135    )?;
3136    writeln!(code, "            )")?;
3137    writeln!(code, "        }};")?;
3138    writeln!(code)?;
3139
3140    if is_q8 {
3141        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3142        // We load them directly into Metal buffers without dequantizing,
3143        // and use the matmul_vec_q8 shader that operates on quantized data.
3144        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
3145        writeln!(
3146            code,
3147            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3148        )?;
3149        writeln!(
3150            code,
3151            "        // as raw bytes into a Metal buffer (no dequantization)."
3152        )?;
3153        writeln!(
3154            code,
3155            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3156        )?;
3157        writeln!(
3158            code,
3159            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3160        )?;
3161        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3162        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3163        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3164        writeln!(code, "            let cur = cursor.get();")?;
3165        writeln!(
3166            code,
3167            "            let data = &weights[cur..cur + total_raw];"
3168        )?;
3169        writeln!(code, "            cursor.set(cur + total_raw);")?;
3170        writeln!(code, "            device.new_buffer_with_data(")?;
3171        writeln!(code, "                data.as_ptr() as *const _,")?;
3172        writeln!(code, "                total_raw as u64,")?;
3173        writeln!(
3174            code,
3175            "                MTLResourceOptions::StorageModeShared,"
3176        )?;
3177        writeln!(code, "            )")?;
3178        writeln!(code, "        }};")?;
3179        writeln!(code)?;
3180        writeln!(
3181            code,
3182            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3183        )?;
3184        writeln!(
3185            code,
3186            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3187        )?;
3188        writeln!(
3189            code,
3190            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3191        )?;
3192        writeln!(
3193            code,
3194            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3195        )?;
3196        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3197        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3198        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3199        writeln!(code, "            let cur = cursor.get();")?;
3200        writeln!(
3201            code,
3202            "            let data = &weights[cur..cur + total_raw];"
3203        )?;
3204        writeln!(code, "            cursor.set(cur + total_raw);")?;
3205        writeln!(code, "            device.new_buffer_with_data(")?;
3206        writeln!(code, "                data.as_ptr() as *const _,")?;
3207        writeln!(code, "                total_raw as u64,")?;
3208        writeln!(
3209            code,
3210            "                MTLResourceOptions::StorageModeShared,"
3211        )?;
3212        writeln!(code, "            )")?;
3213        writeln!(code, "        }};")?;
3214        writeln!(code)?;
3215    }
3216
3217    if is_q4 {
3218        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3219        // We load them directly into Metal buffers without dequantizing,
3220        // and use the matmul_vec_q4 shader that operates on quantized data.
3221        // This quarters GPU memory usage vs f32 dequantization.
3222        writeln!(
3223            code,
3224            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3225        )?;
3226        writeln!(
3227            code,
3228            "        // as raw bytes into a Metal buffer (no dequantization)."
3229        )?;
3230        writeln!(
3231            code,
3232            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3233        )?;
3234        writeln!(
3235            code,
3236            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3237        )?;
3238        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3239        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3240        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3241        writeln!(code, "            let cur = cursor.get();")?;
3242        writeln!(
3243            code,
3244            "            let data = &weights[cur..cur + total_raw];"
3245        )?;
3246        writeln!(code, "            cursor.set(cur + total_raw);")?;
3247        writeln!(code, "            device.new_buffer_with_data(")?;
3248        writeln!(code, "                data.as_ptr() as *const _,")?;
3249        writeln!(code, "                total_raw as u64,")?;
3250        writeln!(
3251            code,
3252            "                MTLResourceOptions::StorageModeShared,"
3253        )?;
3254        writeln!(code, "            )")?;
3255        writeln!(code, "        }};")?;
3256        writeln!(code)?;
3257        writeln!(
3258            code,
3259            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3260        )?;
3261        writeln!(
3262            code,
3263            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3264        )?;
3265        writeln!(
3266            code,
3267            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3268        )?;
3269        writeln!(
3270            code,
3271            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3272        )?;
3273        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3274        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3275        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3276        writeln!(code, "            let cur = cursor.get();")?;
3277        writeln!(
3278            code,
3279            "            let data = &weights[cur..cur + total_raw];"
3280        )?;
3281        writeln!(code, "            cursor.set(cur + total_raw);")?;
3282        writeln!(code, "            device.new_buffer_with_data(")?;
3283        writeln!(code, "                data.as_ptr() as *const _,")?;
3284        writeln!(code, "                total_raw as u64,")?;
3285        writeln!(
3286            code,
3287            "                MTLResourceOptions::StorageModeShared,"
3288        )?;
3289        writeln!(code, "            )")?;
3290        writeln!(code, "        }};")?;
3291        writeln!(code)?;
3292    }
3293
3294    writeln!(
3295        code,
3296        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3297    )?;
3298    writeln!(code)?;
3299
3300    // Per-layer weights
3301    writeln!(
3302        code,
3303        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3304    )?;
3305    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3306
3307    // attn_norm is always f32
3308    writeln!(
3309        code,
3310        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3311    )?;
3312
3313    let qkv_rows = hidden + 2 * kv_dim;
3314    if is_q8 {
3315        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3316        writeln!(
3317            code,
3318            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3319        )?;
3320        if config.qkv_bias {
3321            writeln!(
3322                code,
3323                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3324            )?;
3325            writeln!(
3326                code,
3327                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3328            )?;
3329        }
3330        writeln!(
3331            code,
3332            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3333        )?;
3334    } else if is_q4 {
3335        writeln!(
3336            code,
3337            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3338        )?;
3339        if config.qkv_bias {
3340            writeln!(
3341                code,
3342                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3343            )?;
3344        }
3345        writeln!(
3346            code,
3347            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3348        )?;
3349    } else {
3350        writeln!(
3351            code,
3352            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3353        )?;
3354        if config.qkv_bias {
3355            writeln!(
3356                code,
3357                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3358            )?;
3359        }
3360        writeln!(
3361            code,
3362            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3363        )?;
3364    }
3365
3366    // ffn_norm is always f32
3367    writeln!(
3368        code,
3369        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3370    )?;
3371
3372    let gate_up_rows = 2 * intermediate;
3373    if is_q8 {
3374        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3375        writeln!(
3376            code,
3377            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3378        )?;
3379        writeln!(
3380            code,
3381            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3382        )?;
3383    } else if is_q4 {
3384        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3385        writeln!(
3386            code,
3387            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3388        )?;
3389        writeln!(
3390            code,
3391            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3392        )?;
3393    } else {
3394        // Fused gate+up weight: read both as a single contiguous f32 buffer
3395        writeln!(
3396            code,
3397            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3398        )?;
3399        writeln!(
3400            code,
3401            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3402        )?;
3403    }
3404
3405    writeln!(code, "            layers.push(LayerBuffers {{")?;
3406    writeln!(code, "                attn_norm,")?;
3407    writeln!(code, "                qkv_weight,")?;
3408    if config.qkv_bias {
3409        writeln!(code, "                qkv_bias,")?;
3410    }
3411    writeln!(code, "                o_weight,")?;
3412    writeln!(code, "                ffn_norm,")?;
3413    writeln!(code, "                gate_up_weight,")?;
3414    writeln!(code, "                down_weight,")?;
3415    writeln!(code, "            }});")?;
3416    writeln!(code, "        }}")?;
3417    writeln!(code)?;
3418
3419    // final_norm is always f32
3420    writeln!(
3421        code,
3422        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3423    )?;
3424    writeln!(code)?;
3425
3426    // lm_head
3427    if is_q8 {
3428        writeln!(
3429            code,
3430            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3431        )?;
3432    } else if is_q4 {
3433        writeln!(
3434            code,
3435            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3436        )?;
3437    } else {
3438        writeln!(
3439            code,
3440            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3441        )?;
3442    }
3443    writeln!(code)?;
3444
3445    // Working buffers
3446    let hidden_bytes = hidden * 4;
3447    let _kv_dim_bytes = kv_dim * 4;
3448    let intermediate_bytes = intermediate * 4;
3449    let vocab_bytes = vocab * 4;
3450    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3451
3452    writeln!(
3453        code,
3454        "        // Allocate working buffers (shared memory for zero-copy)"
3455    )?;
3456    writeln!(
3457        code,
3458        "        let opts = MTLResourceOptions::StorageModeShared;"
3459    )?;
3460    writeln!(
3461        code,
3462        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3463    )?;
3464    writeln!(
3465        code,
3466        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3467    )?;
3468    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3469    writeln!(
3470        code,
3471        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3472    )?;
3473    writeln!(
3474        code,
3475        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3476    )?;
3477    writeln!(
3478        code,
3479        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3480    )?;
3481    writeln!(
3482        code,
3483        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3484    )?;
3485    writeln!(
3486        code,
3487        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3488    )?;
3489    let gate_up_buf_bytes = 2 * intermediate * 4;
3490    writeln!(
3491        code,
3492        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3493    )?;
3494    writeln!(
3495        code,
3496        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3497    )?;
3498    writeln!(
3499        code,
3500        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3501    )?;
3502    writeln!(
3503        code,
3504        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3505    )?;
3506    writeln!(
3507        code,
3508        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3509    )?;
3510    writeln!(
3511        code,
3512        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3513    )?;
3514    writeln!(code)?;
3515
3516    // Batch prefill working buffers
3517    let batch_hidden_bytes = hidden * 4; // per-token
3518    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3519    let batch_gate_up_bytes = 2 * intermediate * 4;
3520    let batch_intermediate_bytes = intermediate * 4;
3521    writeln!(
3522        code,
3523        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3524    )?;
3525    writeln!(
3526        code,
3527        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3528    )?;
3529    writeln!(
3530        code,
3531        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3532    )?;
3533    writeln!(
3534        code,
3535        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3536    )?;
3537    writeln!(
3538        code,
3539        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3540    )?;
3541    writeln!(
3542        code,
3543        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3544    )?;
3545    writeln!(
3546        code,
3547        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3548    )?;
3549    writeln!(
3550        code,
3551        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3552    )?;
3553    writeln!(
3554        code,
3555        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3556    )?;
3557    writeln!(
3558        code,
3559        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3560    )?;
3561    writeln!(
3562        code,
3563        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3564    )?;
3565    writeln!(code)?;
3566
3567    // KV cache buffers
3568    writeln!(code, "        // KV cache buffers (per-layer)")?;
3569    writeln!(
3570        code,
3571        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3572    )?;
3573    writeln!(
3574        code,
3575        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3576    )?;
3577    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3578    writeln!(
3579        code,
3580        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3581    )?;
3582    writeln!(
3583        code,
3584        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3585    )?;
3586    writeln!(code, "        }}")?;
3587    writeln!(code)?;
3588
3589    writeln!(code, "        Self {{")?;
3590    writeln!(code, "            device,")?;
3591    writeln!(code, "            queue,")?;
3592    writeln!(code, "            matmul_pipeline,")?;
3593    writeln!(code, "            matmul_q8_pipeline,")?;
3594    writeln!(code, "            matmul_q4_pipeline,")?;
3595    writeln!(code, "            rms_norm_pipeline,")?;
3596    writeln!(code, "            rope_pipeline,")?;
3597    writeln!(code, "            softmax_pipeline,")?;
3598    writeln!(code, "            silu_mul_pipeline,")?;
3599    writeln!(code, "            silu_mul_fused_pipeline,")?;
3600    writeln!(code, "            add_pipeline,")?;
3601    writeln!(code, "            attention_pipeline,")?;
3602    writeln!(code, "            add_inplace_pipeline,")?;
3603    writeln!(code, "            copy_pipeline,")?;
3604    writeln!(code, "            copy_offset_pipeline,")?;
3605    writeln!(code, "            matmul_batch_pipeline,")?;
3606    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3607    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3608    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3609    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3610    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3611    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3612    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3613    if config.qkv_bias {
3614        writeln!(code, "            add_bias_batch_pipeline,")?;
3615    }
3616    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3617    writeln!(code, "            rms_norm_batch_pipeline,")?;
3618    writeln!(code, "            rope_batch_pipeline,")?;
3619    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3620    writeln!(code, "            add_inplace_batch_pipeline,")?;
3621    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3622    writeln!(code, "            attention_batch_pipeline,")?;
3623    writeln!(code, "            attention_flash_batch_pipeline,")?;
3624    writeln!(code, "            copy_kv_batch_pipeline,")?;
3625    writeln!(code, "            rope_qk_batch_pipeline,")?;
3626    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3627    writeln!(code, "            embed_buf,")?;
3628    writeln!(code, "            layers,")?;
3629    writeln!(code, "            norm_buf,")?;
3630    writeln!(code, "            lm_head_buf,")?;
3631    writeln!(code, "            hidden_buf,")?;
3632    writeln!(code, "            residual_buf,")?;
3633    writeln!(code, "            normed_buf,")?;
3634    writeln!(code, "            qkv_buf,")?;
3635    writeln!(code, "            attn_out_buf,")?;
3636    writeln!(code, "            attn_proj_buf,")?;
3637    writeln!(code, "            gate_up_buf,")?;
3638    writeln!(code, "            ffn_hidden_buf,")?;
3639    writeln!(code, "            ffn_out_buf,")?;
3640    writeln!(code, "            add_tmp_buf,")?;
3641    writeln!(code, "            logits_buf,")?;
3642    writeln!(code, "            batch_hidden_buf,")?;
3643    writeln!(code, "            batch_residual_buf,")?;
3644    writeln!(code, "            batch_qkv_buf,")?;
3645    writeln!(code, "            batch_attn_out_buf,")?;
3646    writeln!(code, "            batch_attn_proj_buf,")?;
3647    writeln!(code, "            batch_gate_up_buf,")?;
3648    writeln!(code, "            batch_ffn_hidden_buf,")?;
3649    writeln!(code, "            batch_ffn_out_buf,")?;
3650    writeln!(code, "            batch_tokens_buf,")?;
3651    writeln!(code, "            batch_positions_buf,")?;
3652    writeln!(code, "            k_cache,")?;
3653    writeln!(code, "            v_cache,")?;
3654    writeln!(code, "            pos: 0,")?;
3655    writeln!(code, "            prev_cmd: None,")?;
3656    writeln!(code, "        }}")?;
3657    writeln!(code, "    }}")?;
3658    writeln!(code)?;
3659
3660    // ── forward() ──
3661    writeln!(
3662        code,
3663        "    /// Run the forward pass for a single token at the current position."
3664    )?;
3665    writeln!(code, "    ///")?;
3666    writeln!(
3667        code,
3668        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3669    )?;
3670    writeln!(code, "    ///")?;
3671    writeln!(
3672        code,
3673        "    /// All GPU operations are encoded into a single command buffer and"
3674    )?;
3675    writeln!(
3676        code,
3677        "    /// committed once at the end, avoiding per-operation synchronization."
3678    )?;
3679    writeln!(
3680        code,
3681        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3682    )?;
3683    writeln!(
3684        code,
3685        "        // Wait for any pending prefill command buffer"
3686    )?;
3687    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3688    writeln!(code, "            prev.wait_until_completed();")?;
3689    writeln!(code, "        }}")?;
3690    writeln!(code)?;
3691    writeln!(code, "        let pos = self.pos;")?;
3692    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3693    writeln!(code)?;
3694
3695    // Single compute encoder for the entire forward pass — no blit encoder
3696    // transitions. Copy operations use compute copy kernels instead of blits.
3697    let matmul_fn = if is_q8 {
3698        "dispatch_matmul_q8"
3699    } else if is_q4 {
3700        "dispatch_matmul_q4"
3701    } else {
3702        "dispatch_matmul"
3703    };
3704
3705    writeln!(
3706        code,
3707        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3708    )?;
3709    writeln!(code, "        {{")?;
3710    writeln!(
3711        code,
3712        "            let enc = cmd.new_compute_command_encoder();"
3713    )?;
3714    writeln!(code)?;
3715
3716    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3717    writeln!(
3718        code,
3719        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3720    )?;
3721    writeln!(
3722        code,
3723        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3724    )?;
3725    writeln!(
3726        code,
3727        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3728    )?;
3729    writeln!(
3730        code,
3731        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3732        hidden * 4,
3733    )?;
3734    writeln!(code, "            unsafe {{")?;
3735    writeln!(
3736        code,
3737        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
3738    )?;
3739    writeln!(
3740        code,
3741        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3742    )?;
3743    writeln!(
3744        code,
3745        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
3746    )?;
3747    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
3748    writeln!(
3749        code,
3750        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3751    )?;
3752    writeln!(code, "                    hidden_ptr,")?;
3753    writeln!(code, "                    HIDDEN_SIZE,")?;
3754    writeln!(code, "                );")?;
3755    writeln!(
3756        code,
3757        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3758    )?;
3759    writeln!(code, "            }}")?;
3760    writeln!(code)?;
3761
3762    // 2. Transformer layers
3763    writeln!(code, "            // 2. Transformer layers")?;
3764    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3765    writeln!(code)?;
3766    let q_byte_offset = 0usize;
3767    let k_byte_offset = hidden * 4;
3768    let v_byte_offset = (hidden + kv_dim) * 4;
3769
3770    writeln!(
3771        code,
3772        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
3773    )?;
3774    writeln!(
3775        code,
3776        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3777    )?;
3778    writeln!(
3779        code,
3780        "                // Fused Q+K+V matmul: single dispatch for all three projections"
3781    )?;
3782    writeln!(
3783        code,
3784        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3785    )?;
3786    if config.qkv_bias {
3787        writeln!(
3788            code,
3789            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
3790        )?;
3791        writeln!(
3792            code,
3793            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3794        )?;
3795    }
3796    writeln!(
3797        code,
3798        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
3799    )?;
3800    writeln!(
3801        code,
3802        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3803    )?;
3804    writeln!(
3805        code,
3806        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3807    )?;
3808    writeln!(code)?;
3809    writeln!(
3810        code,
3811        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
3812    )?;
3813    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3814    writeln!(
3815        code,
3816        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3817    )?;
3818    writeln!(
3819        code,
3820        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3821    )?;
3822    writeln!(code)?;
3823    writeln!(
3824        code,
3825        "                // Attention using Q from qkv_buf (offset 0)"
3826    )?;
3827    writeln!(
3828        code,
3829        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3830    )?;
3831    writeln!(
3832        code,
3833        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3834    )?;
3835    writeln!(
3836        code,
3837        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3838    )?;
3839    writeln!(
3840        code,
3841        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
3842    )?;
3843    writeln!(
3844        code,
3845        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3846    )?;
3847    writeln!(
3848        code,
3849        "                // Fused gate+up matmul: single dispatch for both projections"
3850    )?;
3851    writeln!(
3852        code,
3853        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3854    )?;
3855    writeln!(
3856        code,
3857        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3858    )?;
3859    writeln!(
3860        code,
3861        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3862    )?;
3863    writeln!(
3864        code,
3865        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3866    )?;
3867    writeln!(code, "            }}")?;
3868    writeln!(code)?;
3869
3870    // 3. Final RMS norm + logits
3871    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
3872    writeln!(
3873        code,
3874        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3875    )?;
3876    writeln!(
3877        code,
3878        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3879    )?;
3880    writeln!(code)?;
3881    writeln!(code, "            enc.end_encoding();")?;
3882    writeln!(code, "        }}")?;
3883    writeln!(code)?;
3884
3885    // 5. Single commit + wait, then read back logits
3886    writeln!(
3887        code,
3888        "        // 5. Commit all GPU work and wait for completion"
3889    )?;
3890    writeln!(code, "        cmd.commit();")?;
3891    writeln!(code, "        cmd.wait_until_completed();")?;
3892    writeln!(code)?;
3893    writeln!(code, "        // 6. Read back logits from GPU")?;
3894    writeln!(code, "        let logits = unsafe {{")?;
3895    writeln!(
3896        code,
3897        "            let ptr = self.logits_buf.contents() as *const f32;"
3898    )?;
3899    writeln!(
3900        code,
3901        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3902    )?;
3903    writeln!(code, "        }};")?;
3904    writeln!(code)?;
3905    writeln!(code, "        self.pos += 1;")?;
3906    writeln!(code, "        logits")?;
3907    writeln!(code, "    }}")?;
3908    writeln!(code)?;
3909
3910    // ── forward_profile: instrumented forward with per-operation timing ──
3911    writeln!(
3912        code,
3913        "    /// Profiling forward pass that prints per-stage GPU timing."
3914    )?;
3915    writeln!(code, "    ///")?;
3916    writeln!(
3917        code,
3918        "    /// Each stage is committed and waited on separately so that GPU timestamps"
3919    )?;
3920    writeln!(
3921        code,
3922        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
3923    )?;
3924    writeln!(
3925        code,
3926        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
3927    )?;
3928    writeln!(
3929        code,
3930        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
3931    )?;
3932    writeln!(code, "        use std::time::Instant;")?;
3933    writeln!(code)?;
3934    writeln!(
3935        code,
3936        "        // Wait for any pending prefill command buffer"
3937    )?;
3938    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3939    writeln!(code, "            prev.wait_until_completed();")?;
3940    writeln!(code, "        }}")?;
3941    writeln!(code)?;
3942    writeln!(code, "        let pos = self.pos;")?;
3943    writeln!(code)?;
3944
3945    // Stage: embedding (CPU, no GPU)
3946    writeln!(
3947        code,
3948        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
3949    )?;
3950    writeln!(code, "        let t_embed = Instant::now();")?;
3951    writeln!(code, "        unsafe {{")?;
3952    writeln!(
3953        code,
3954        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
3955    )?;
3956    writeln!(
3957        code,
3958        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3959    )?;
3960    writeln!(
3961        code,
3962        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
3963    )?;
3964    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
3965    writeln!(
3966        code,
3967        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3968    )?;
3969    writeln!(code, "                hidden_ptr,")?;
3970    writeln!(code, "                HIDDEN_SIZE,")?;
3971    writeln!(code, "            );")?;
3972    writeln!(
3973        code,
3974        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3975    )?;
3976    writeln!(code, "        }}")?;
3977    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
3978    writeln!(code)?;
3979
3980    // Stage: Transformer layers (all together on GPU)
3981    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
3982    writeln!(code, "        let t_layers = Instant::now();")?;
3983    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3984    writeln!(code, "        {{")?;
3985    writeln!(
3986        code,
3987        "            let enc = cmd.new_compute_command_encoder();"
3988    )?;
3989    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3990    writeln!(
3991        code,
3992        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3993    )?;
3994    writeln!(
3995        code,
3996        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3997    )?;
3998    if config.qkv_bias {
3999        writeln!(
4000            code,
4001            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4002        )?;
4003    }
4004    writeln!(
4005        code,
4006        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4007    )?;
4008    writeln!(
4009        code,
4010        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4011    )?;
4012    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
4013    writeln!(
4014        code,
4015        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4016    )?;
4017    writeln!(
4018        code,
4019        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4020    )?;
4021    writeln!(
4022        code,
4023        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4024    )?;
4025    writeln!(
4026        code,
4027        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4028    )?;
4029    writeln!(
4030        code,
4031        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4032    )?;
4033    writeln!(
4034        code,
4035        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4036    )?;
4037    writeln!(
4038        code,
4039        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4040    )?;
4041    writeln!(
4042        code,
4043        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4044    )?;
4045    writeln!(
4046        code,
4047        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4048    )?;
4049    writeln!(
4050        code,
4051        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4052    )?;
4053    writeln!(code, "            }}")?;
4054    writeln!(code, "            enc.end_encoding();")?;
4055    writeln!(code, "        }}")?;
4056    writeln!(code, "        cmd.commit();")?;
4057    writeln!(code, "        cmd.wait_until_completed();")?;
4058    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4059    writeln!(code)?;
4060
4061    // Stage: Final norm + logits
4062    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4063    writeln!(code, "        let t_logits = Instant::now();")?;
4064    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4065    writeln!(code, "        {{")?;
4066    writeln!(
4067        code,
4068        "            let enc = cmd2.new_compute_command_encoder();"
4069    )?;
4070    writeln!(
4071        code,
4072        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4073    )?;
4074    writeln!(
4075        code,
4076        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4077    )?;
4078    writeln!(code, "            enc.end_encoding();")?;
4079    writeln!(code, "        }}")?;
4080    writeln!(code, "        cmd2.commit();")?;
4081    writeln!(code, "        cmd2.wait_until_completed();")?;
4082    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4083    writeln!(code)?;
4084
4085    // Print profile results
4086    writeln!(
4087        code,
4088        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4089    )?;
4090    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4091    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4092    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4093    writeln!(
4094        code,
4095        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4096    )?;
4097    writeln!(code)?;
4098
4099    // Read back logits
4100    writeln!(code, "        let logits = unsafe {{")?;
4101    writeln!(
4102        code,
4103        "            let ptr = self.logits_buf.contents() as *const f32;"
4104    )?;
4105    writeln!(
4106        code,
4107        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4108    )?;
4109    writeln!(code, "        }};")?;
4110    writeln!(code)?;
4111    writeln!(code, "        self.pos += 1;")?;
4112    writeln!(code, "        logits")?;
4113    writeln!(code, "    }}")?;
4114    writeln!(code)?;
4115
4116    // ── forward_prefill: single-token async forward (backward compat) ──
4117    writeln!(
4118        code,
4119        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4120    )?;
4121    writeln!(code, "    ///")?;
4122    writeln!(
4123        code,
4124        "    /// Commits the command buffer without waiting, enabling double-buffered"
4125    )?;
4126    writeln!(
4127        code,
4128        "    /// execution: GPU processes token N while CPU encodes token N+1."
4129    )?;
4130    writeln!(
4131        code,
4132        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4133    )?;
4134    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4135    writeln!(code, "    }}")?;
4136    writeln!(code)?;
4137
4138    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4139    // Batched matmuls for QKV/O/FFN projections, sequential attention (causal dependency).
4140    let batch_matmul_fn = if is_q8 {
4141        "dispatch_matmul_q8_batch"
4142    } else if is_q4 {
4143        "dispatch_matmul_q4_batch"
4144    } else {
4145        "dispatch_matmul_batch"
4146    };
4147
4148    writeln!(
4149        code,
4150        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4151    )?;
4152    writeln!(code, "    ///")?;
4153    writeln!(
4154        code,
4155        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4156    )?;
4157    writeln!(
4158        code,
4159        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4160    )?;
4161    writeln!(
4162        code,
4163        "    /// This provides significant speedup during prompt prefill."
4164    )?;
4165    writeln!(
4166        code,
4167        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4168    )?;
4169    writeln!(code, "        if tokens.is_empty() {{ return; }}")?;
4170    writeln!(
4171        code,
4172        "        // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4173    )?;
4174    writeln!(
4175        code,
4176        "        // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4177    )?;
4178    writeln!(
4179        code,
4180        "        // longer than that must be processed iteratively.  The KV cache"
4181    )?;
4182    writeln!(
4183        code,
4184        "        // carries state across chunks via self.pos."
4185    )?;
4186    writeln!(code, "        for chunk in tokens.chunks(MAX_BATCH_SIZE) {{")?;
4187    writeln!(code, "        let m = chunk.len();")?;
4188    writeln!(code, "        if m == 0 {{ continue; }}")?;
4189    writeln!(code, "        let start_pos = self.pos;")?;
4190    writeln!(code)?;
4191    writeln!(code, "        // Wait for any pending command buffer")?;
4192    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4193    writeln!(code, "            prev.wait_until_completed();")?;
4194    writeln!(code, "        }}")?;
4195    writeln!(code)?;
4196
4197    // Upload token IDs and positions to GPU
4198    writeln!(
4199        code,
4200        "        // Upload token IDs and positions to GPU buffers"
4201    )?;
4202    writeln!(code, "        unsafe {{")?;
4203    writeln!(
4204        code,
4205        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4206    )?;
4207    writeln!(
4208        code,
4209        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4210    )?;
4211    writeln!(code, "            for i in 0..m {{")?;
4212    writeln!(code, "                *tok_ptr.add(i) = chunk[i];")?;
4213    writeln!(
4214        code,
4215        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4216    )?;
4217    writeln!(code, "            }}")?;
4218    writeln!(code, "        }}")?;
4219    writeln!(code)?;
4220
4221    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4222    writeln!(code, "        {{")?;
4223    writeln!(
4224        code,
4225        "            let enc = cmd.new_compute_command_encoder();"
4226    )?;
4227    writeln!(code)?;
4228
4229    // 1. Batch embedding lookup
4230    writeln!(
4231        code,
4232        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4233    )?;
4234    writeln!(
4235        code,
4236        "            self.dispatch_copy_embedding_batch(&enc, m);"
4237    )?;
4238    // Copy batch_hidden -> batch_residual
4239    writeln!(
4240        code,
4241        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4242    )?;
4243    writeln!(code)?;
4244
4245    // 2. Transformer layers
4246    writeln!(code, "            // 2. Transformer layers")?;
4247    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4248    writeln!(code)?;
4249
4250    // Batch RMS norm: residual -> hidden (batched)
4251    writeln!(
4252        code,
4253        "                // Batch RMS norm: batch_residual -> batch_hidden"
4254    )?;
4255    writeln!(
4256        code,
4257        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4258    )?;
4259
4260    // Batch QKV matmul
4261    writeln!(
4262        code,
4263        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4264    )?;
4265    writeln!(
4266        code,
4267        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4268    )?;
4269    if config.qkv_bias {
4270        writeln!(
4271            code,
4272            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4273        )?;
4274        writeln!(
4275            code,
4276            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4277        )?;
4278    }
4279    writeln!(code)?;
4280
4281    // Fused RoPE on Q+K portions in a single dispatch
4282    let k_float_offset = hidden;
4283    writeln!(
4284        code,
4285        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4286    )?;
4287    writeln!(
4288        code,
4289        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4290    )?;
4291    writeln!(code)?;
4292
4293    // Fused KV cache update: copy both K and V in a single dispatch
4294    let v_float_offset = hidden + kv_dim;
4295    writeln!(
4296        code,
4297        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4298    )?;
4299    writeln!(
4300        code,
4301        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4302    )?;
4303    writeln!(code)?;
4304
4305    // Batched causal attention: ONE dispatch for all M tokens
4306    writeln!(
4307        code,
4308        "                // Batched causal attention: one dispatch for all M tokens"
4309    )?;
4310    writeln!(
4311        code,
4312        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4313    )?;
4314    writeln!(code)?;
4315
4316    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4317    writeln!(code, "                // Batched O projection")?;
4318    writeln!(
4319        code,
4320        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4321    )?;
4322    writeln!(code)?;
4323
4324    // Batch add: residual += attn_proj for all tokens
4325    writeln!(
4326        code,
4327        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4328    )?;
4329    writeln!(code)?;
4330
4331    // Batch FFN
4332    writeln!(
4333        code,
4334        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4335    )?;
4336    writeln!(
4337        code,
4338        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4339    )?;
4340    writeln!(
4341        code,
4342        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4343    )?;
4344    writeln!(
4345        code,
4346        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4347    )?;
4348    writeln!(
4349        code,
4350        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4351    )?;
4352    writeln!(
4353        code,
4354        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4355    )?;
4356    writeln!(code, "            }}")?;
4357    writeln!(code)?;
4358
4359    // Copy last token's residual to single-token residual_buf for next forward() call
4360    writeln!(
4361        code,
4362        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4363    )?;
4364    writeln!(
4365        code,
4366        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4367    )?;
4368    writeln!(code)?;
4369    writeln!(code, "            enc.end_encoding();")?;
4370    writeln!(code, "        }}")?;
4371    writeln!(code)?;
4372
4373    writeln!(code, "        cmd.commit();")?;
4374    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4375    writeln!(code, "        self.pos += m;")?;
4376    writeln!(code, "        }}  // end for chunk")?;
4377    writeln!(code, "    }}")?;
4378    writeln!(code)?;
4379
4380    // ── reset() — rewind KV cache position for new inference requests ──
4381    writeln!(
4382        code,
4383        "    /// Reset the model state for a new inference request."
4384    )?;
4385    writeln!(code, "    pub fn reset(&mut self) {{")?;
4386    writeln!(code, "        self.pos = 0;")?;
4387    writeln!(code, "        self.prev_cmd = None;")?;
4388    writeln!(code, "    }}")?;
4389    writeln!(code)?;
4390
4391    // ── Private dispatch helpers (all take a shared compute encoder) ──
4392    writeln!(
4393        code,
4394        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4395    )?;
4396    writeln!(
4397        code,
4398        "    // These methods set pipeline state + buffers + dispatch on an existing"
4399    )?;
4400    writeln!(
4401        code,
4402        "    // encoder, avoiding per-operation encoder creation overhead."
4403    )?;
4404    writeln!(code)?;
4405
4406    // dispatch_rms_norm
4407    writeln!(
4408        code,
4409        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4410    )?;
4411    writeln!(
4412        code,
4413        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4414    )?;
4415    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4416    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4417    writeln!(
4418        code,
4419        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4420    )?;
4421    writeln!(
4422        code,
4423        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4424    )?;
4425    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4426    writeln!(
4427        code,
4428        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4429    )?;
4430    writeln!(
4431        code,
4432        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4433    )?;
4434    writeln!(
4435        code,
4436        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4437    )?;
4438    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4439    writeln!(
4440        code,
4441        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4442    )?;
4443    writeln!(
4444        code,
4445        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4446    )?;
4447    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4448    writeln!(code, "    }}")?;
4449    writeln!(code)?;
4450
4451    // dispatch_matmul
4452    writeln!(
4453        code,
4454        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4455    )?;
4456    writeln!(
4457        code,
4458        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4459    )?;
4460    writeln!(code, "        let r: u32 = rows as u32;")?;
4461    writeln!(code, "        let c: u32 = cols as u32;")?;
4462    writeln!(
4463        code,
4464        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4465    )?;
4466    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4467    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4468    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4469    writeln!(
4470        code,
4471        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4472    )?;
4473    writeln!(
4474        code,
4475        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4476    )?;
4477    writeln!(
4478        code,
4479        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4480    )?;
4481    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4482    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4483    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4484    writeln!(
4485        code,
4486        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4487    )?;
4488    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4489    writeln!(code, "    }}")?;
4490    writeln!(code)?;
4491
4492    // dispatch_matmul_q8
4493    writeln!(
4494        code,
4495        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4496    )?;
4497    writeln!(
4498        code,
4499        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4500    )?;
4501    writeln!(
4502        code,
4503        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4504    )?;
4505    writeln!(code, "        let r: u32 = rows as u32;")?;
4506    writeln!(code, "        let c: u32 = cols as u32;")?;
4507    writeln!(
4508        code,
4509        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4510    )?;
4511    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4512    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4513    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4514    writeln!(
4515        code,
4516        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4517    )?;
4518    writeln!(
4519        code,
4520        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4521    )?;
4522    writeln!(
4523        code,
4524        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4525    )?;
4526    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4527    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4528    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4529    writeln!(
4530        code,
4531        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4532    )?;
4533    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4534    writeln!(code, "    }}")?;
4535    writeln!(code)?;
4536
4537    // dispatch_matmul_q4
4538    writeln!(
4539        code,
4540        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4541    )?;
4542    writeln!(
4543        code,
4544        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4545    )?;
4546    writeln!(
4547        code,
4548        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4549    )?;
4550    writeln!(code, "        let r: u32 = rows as u32;")?;
4551    writeln!(code, "        let c: u32 = cols as u32;")?;
4552    writeln!(
4553        code,
4554        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4555    )?;
4556    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4557    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4558    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4559    writeln!(
4560        code,
4561        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4562    )?;
4563    writeln!(
4564        code,
4565        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4566    )?;
4567    writeln!(
4568        code,
4569        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4570    )?;
4571    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4572    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4573    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4574    writeln!(
4575        code,
4576        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4577    )?;
4578    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4579    writeln!(code, "    }}")?;
4580    writeln!(code)?;
4581
4582    // dispatch_rope
4583    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4584    writeln!(
4585        code,
4586        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4587    )?;
4588    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4589    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4590    writeln!(code, "        let p: u32 = pos as u32;")?;
4591    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4592    writeln!(
4593        code,
4594        "        let total_pairs = num_heads * (head_dim / 2);"
4595    )?;
4596    writeln!(
4597        code,
4598        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4599    )?;
4600    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4601    writeln!(
4602        code,
4603        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4604    )?;
4605    writeln!(
4606        code,
4607        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4608    )?;
4609    writeln!(
4610        code,
4611        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4612    )?;
4613    writeln!(
4614        code,
4615        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4616    )?;
4617    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4618    writeln!(
4619        code,
4620        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4621    )?;
4622    writeln!(
4623        code,
4624        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4625    )?;
4626    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4627    writeln!(code, "    }}")?;
4628    writeln!(code)?;
4629
4630    // dispatch_rope_offset
4631    writeln!(
4632        code,
4633        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4634    )?;
4635    writeln!(
4636        code,
4637        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4638    )?;
4639    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4640    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4641    writeln!(code, "        let p: u32 = pos as u32;")?;
4642    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4643    writeln!(
4644        code,
4645        "        let total_pairs = num_heads * (head_dim / 2);"
4646    )?;
4647    writeln!(
4648        code,
4649        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4650    )?;
4651    writeln!(
4652        code,
4653        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4654    )?;
4655    writeln!(
4656        code,
4657        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4658    )?;
4659    writeln!(
4660        code,
4661        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4662    )?;
4663    writeln!(
4664        code,
4665        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4666    )?;
4667    writeln!(
4668        code,
4669        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4670    )?;
4671    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4672    writeln!(
4673        code,
4674        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4675    )?;
4676    writeln!(
4677        code,
4678        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4679    )?;
4680    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4681    writeln!(code, "    }}")?;
4682    writeln!(code)?;
4683
4684    // dispatch_attention
4685    writeln!(
4686        code,
4687        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4688    )?;
4689    writeln!(
4690        code,
4691        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4692    )?;
4693    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4694    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4695    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4696    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4697    writeln!(
4698        code,
4699        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4700    )?;
4701    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4702    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4703    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4704    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4705    writeln!(
4706        code,
4707        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4708    )?;
4709    writeln!(
4710        code,
4711        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4712    )?;
4713    writeln!(
4714        code,
4715        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4716    )?;
4717    writeln!(
4718        code,
4719        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4720    )?;
4721    writeln!(
4722        code,
4723        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4724    )?;
4725    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4726    writeln!(
4727        code,
4728        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4729    )?;
4730    writeln!(
4731        code,
4732        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4733    )?;
4734    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4735    writeln!(code, "    }}")?;
4736    writeln!(code)?;
4737
4738    // dispatch_attention_offset
4739    writeln!(
4740        code,
4741        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4742    )?;
4743    writeln!(
4744        code,
4745        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4746    )?;
4747    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4748    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4749    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4750    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4751    writeln!(
4752        code,
4753        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4754    )?;
4755    writeln!(
4756        code,
4757        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
4758    )?;
4759    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4760    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4761    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4762    writeln!(
4763        code,
4764        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4765    )?;
4766    writeln!(
4767        code,
4768        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4769    )?;
4770    writeln!(
4771        code,
4772        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4773    )?;
4774    writeln!(
4775        code,
4776        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4777    )?;
4778    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4779    writeln!(
4780        code,
4781        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4782    )?;
4783    writeln!(
4784        code,
4785        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4786    )?;
4787    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4788    writeln!(code, "    }}")?;
4789    writeln!(code)?;
4790
4791    // dispatch_silu_mul
4792    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
4793    writeln!(
4794        code,
4795        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
4796    )?;
4797    writeln!(code, "        let count: u32 = n as u32;")?;
4798    writeln!(
4799        code,
4800        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
4801    )?;
4802    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
4803    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
4804    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4805    writeln!(
4806        code,
4807        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4808    )?;
4809    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4810    writeln!(
4811        code,
4812        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4813    )?;
4814    writeln!(
4815        code,
4816        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4817    )?;
4818    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4819    writeln!(code, "    }}")?;
4820    writeln!(code)?;
4821
4822    // dispatch_silu_mul_fused
4823    writeln!(
4824        code,
4825        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
4826    )?;
4827    writeln!(
4828        code,
4829        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
4830    )?;
4831    writeln!(
4832        code,
4833        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
4834    )?;
4835    writeln!(code, "        let count: u32 = n as u32;")?;
4836    writeln!(
4837        code,
4838        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
4839    )?;
4840    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
4841    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
4842    writeln!(
4843        code,
4844        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4845    )?;
4846    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4847    writeln!(
4848        code,
4849        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4850    )?;
4851    writeln!(
4852        code,
4853        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4854    )?;
4855    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4856    writeln!(code, "    }}")?;
4857    writeln!(code)?;
4858
4859    // dispatch_copy (simple src -> dst copy via compute kernel)
4860    writeln!(
4861        code,
4862        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
4863    )?;
4864    writeln!(
4865        code,
4866        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4867    )?;
4868    writeln!(code, "        let n: u32 = count as u32;")?;
4869    writeln!(
4870        code,
4871        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4872    )?;
4873    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4874    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4875    writeln!(
4876        code,
4877        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4878    )?;
4879    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4880    writeln!(
4881        code,
4882        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4883    )?;
4884    writeln!(
4885        code,
4886        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4887    )?;
4888    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4889    writeln!(code, "    }}")?;
4890    writeln!(code)?;
4891
4892    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
4893    writeln!(
4894        code,
4895        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
4896    )?;
4897    writeln!(
4898        code,
4899        "    /// Used for embedding table lookup (copy a specific row)."
4900    )?;
4901    writeln!(
4902        code,
4903        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
4904    )?;
4905    writeln!(code, "        let off: u32 = src_offset as u32;")?;
4906    writeln!(code, "        let n: u32 = count as u32;")?;
4907    writeln!(
4908        code,
4909        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
4910    )?;
4911    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4912    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4913    writeln!(
4914        code,
4915        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
4916    )?;
4917    writeln!(
4918        code,
4919        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4920    )?;
4921    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4922    writeln!(
4923        code,
4924        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4925    )?;
4926    writeln!(
4927        code,
4928        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4929    )?;
4930    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4931    writeln!(code, "    }}")?;
4932    writeln!(code)?;
4933
4934    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
4935    writeln!(
4936        code,
4937        "    /// Dispatch copy from source at byte offset to destination at float offset."
4938    )?;
4939    writeln!(
4940        code,
4941        "    /// Used for KV cache updates from fused QKV buffer."
4942    )?;
4943    writeln!(
4944        code,
4945        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4946    )?;
4947    writeln!(code, "        let n: u32 = count as u32;")?;
4948    writeln!(
4949        code,
4950        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4951    )?;
4952    writeln!(
4953        code,
4954        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4955    )?;
4956    writeln!(
4957        code,
4958        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4959    )?;
4960    writeln!(
4961        code,
4962        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4963    )?;
4964    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4965    writeln!(
4966        code,
4967        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4968    )?;
4969    writeln!(
4970        code,
4971        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4972    )?;
4973    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4974    writeln!(code, "    }}")?;
4975    writeln!(code)?;
4976
4977    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
4978    writeln!(
4979        code,
4980        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
4981    )?;
4982    writeln!(
4983        code,
4984        "    /// Used for KV cache updates (write to a specific position in the cache)."
4985    )?;
4986    writeln!(
4987        code,
4988        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
4989    )?;
4990    writeln!(code, "        let n: u32 = count as u32;")?;
4991    writeln!(
4992        code,
4993        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4994    )?;
4995    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4996    writeln!(
4997        code,
4998        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
4999    )?;
5000    writeln!(
5001        code,
5002        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5003    )?;
5004    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5005    writeln!(
5006        code,
5007        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5008    )?;
5009    writeln!(
5010        code,
5011        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5012    )?;
5013    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5014    writeln!(code, "    }}")?;
5015    writeln!(code)?;
5016
5017    // dispatch_add_inplace (residual connection, no blit needed)
5018    writeln!(
5019        code,
5020        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5021    )?;
5022    writeln!(
5023        code,
5024        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5025    )?;
5026    writeln!(code, "        let count: u32 = n as u32;")?;
5027    writeln!(
5028        code,
5029        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5030    )?;
5031    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5032    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5033    writeln!(
5034        code,
5035        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5036    )?;
5037    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5038    writeln!(
5039        code,
5040        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5041    )?;
5042    writeln!(
5043        code,
5044        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5045    )?;
5046    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5047    writeln!(code, "    }}")?;
5048    writeln!(code)?;
5049
5050    // ── Batched prefill dispatch helpers ──
5051    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
5052    writeln!(code)?;
5053
5054    // dispatch_copy_embedding_batch
5055    writeln!(
5056        code,
5057        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
5058    )?;
5059    writeln!(
5060        code,
5061        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5062    )?;
5063    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
5064    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5065    writeln!(
5066        code,
5067        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5068    )?;
5069    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5070    writeln!(
5071        code,
5072        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5073    )?;
5074    writeln!(
5075        code,
5076        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5077    )?;
5078    writeln!(
5079        code,
5080        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5081    )?;
5082    writeln!(
5083        code,
5084        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5085    )?;
5086    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
5087    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5088    writeln!(
5089        code,
5090        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5091    )?;
5092    writeln!(
5093        code,
5094        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5095    )?;
5096    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5097    writeln!(code, "    }}")?;
5098    writeln!(code)?;
5099
5100    // dispatch_rms_norm_batch
5101    writeln!(
5102        code,
5103        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
5104    )?;
5105    writeln!(
5106        code,
5107        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5108    )?;
5109    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
5110    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
5111    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5112    writeln!(
5113        code,
5114        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5115    )?;
5116    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
5117    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
5118    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5119    writeln!(
5120        code,
5121        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5122    )?;
5123    writeln!(
5124        code,
5125        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5126    )?;
5127    writeln!(
5128        code,
5129        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5130    )?;
5131    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5132    writeln!(
5133        code,
5134        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5135    )?;
5136    writeln!(
5137        code,
5138        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5139    )?;
5140    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5141    writeln!(code, "    }}")?;
5142    writeln!(code)?;
5143
5144    // dispatch_matmul_batch (f32)
5145    writeln!(
5146        code,
5147        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5148    )?;
5149    writeln!(
5150        code,
5151        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5152    )?;
5153    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5154    writeln!(code, "        let r: u32 = rows as u32;")?;
5155    writeln!(code, "        let c: u32 = cols as u32;")?;
5156    writeln!(
5157        code,
5158        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5159    )?;
5160    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5161    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5162    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5163    writeln!(
5164        code,
5165        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5166    )?;
5167    writeln!(
5168        code,
5169        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5170    )?;
5171    writeln!(
5172        code,
5173        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5174    )?;
5175    writeln!(
5176        code,
5177        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
5178    )?;
5179    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5180    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5181    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5182    writeln!(
5183        code,
5184        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5185    )?;
5186    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5187    writeln!(code, "    }}")?;
5188    writeln!(code)?;
5189
5190    // dispatch_matmul_q8_batch
5191    writeln!(
5192        code,
5193        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5194    )?;
5195    writeln!(code, "    ///")?;
5196    writeln!(
5197        code,
5198        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5199    )?;
5200    writeln!(
5201        code,
5202        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5203    )?;
5204    writeln!(
5205        code,
5206        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5207    )?;
5208    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5209    writeln!(code, "        let r: u32 = rows as u32;")?;
5210    writeln!(code, "        let c: u32 = cols as u32;")?;
5211    writeln!(
5212        code,
5213        "        // Tile sizes must match the Metal shader constants."
5214    )?;
5215    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5216    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5217    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5218    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5219    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5220    writeln!(
5221        code,
5222        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5223    )?;
5224    writeln!(
5225        code,
5226        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5227    )?;
5228    writeln!(
5229        code,
5230        "        // dispatch count and reuses each weight load across 32 tokens."
5231    )?;
5232    writeln!(
5233        code,
5234        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5235    )?;
5236    writeln!(
5237        code,
5238        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5239    )?;
5240    writeln!(
5241        code,
5242        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5243    )?;
5244    writeln!(
5245        code,
5246        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5247    )?;
5248    writeln!(
5249        code,
5250        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5251    )?;
5252    writeln!(
5253        code,
5254        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5255    )?;
5256    writeln!(
5257        code,
5258        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5259    )?;
5260    writeln!(
5261        code,
5262        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5263    )?;
5264    writeln!(
5265        code,
5266        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5267    )?;
5268    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5269    writeln!(code, "            let pipe = if use_h4 {{")?;
5270    writeln!(code, "                if num_tokens >= 256 {{")?;
5271    writeln!(code, "                    &self.matmul_q8_mma32_hh4_pipeline")?;
5272    writeln!(code, "                }} else {{")?;
5273    writeln!(code, "                    &self.matmul_q8_mma32_h4_pipeline")?;
5274    writeln!(code, "                }}")?;
5275    writeln!(code, "            }} else {{")?;
5276    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5277    writeln!(code, "            }};")?;
5278    writeln!(
5279        code,
5280        "            enc.set_compute_pipeline_state(pipe);"
5281    )?;
5282    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5283    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5284    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5285    writeln!(
5286        code,
5287        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5288    )?;
5289    writeln!(
5290        code,
5291        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5292    )?;
5293    writeln!(
5294        code,
5295        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5296    )?;
5297    writeln!(
5298        code,
5299        "            let row_tgs = rows / MMA32_ROW_TILE;"
5300    )?;
5301    writeln!(
5302        code,
5303        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5304    )?;
5305    writeln!(
5306        code,
5307        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5308    )?;
5309    writeln!(
5310        code,
5311        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5312    )?;
5313    writeln!(
5314        code,
5315        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5316    )?;
5317    writeln!(
5318        code,
5319        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5320    )?;
5321    writeln!(
5322        code,
5323        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5324    )?;
5325    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5326    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5327    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5328    writeln!(
5329        code,
5330        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5331    )?;
5332    writeln!(
5333        code,
5334        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5335    )?;
5336    writeln!(
5337        code,
5338        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5339    )?;
5340    writeln!(
5341        code,
5342        "            let row_tgs = rows / MMA_ROW_TILE;"
5343    )?;
5344    writeln!(
5345        code,
5346        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5347    )?;
5348    writeln!(
5349        code,
5350        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5351    )?;
5352    writeln!(
5353        code,
5354        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5355    )?;
5356    writeln!(
5357        code,
5358        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5359    )?;
5360    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5361    writeln!(
5362        code,
5363        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5364    )?;
5365    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5366    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5367    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5368    writeln!(
5369        code,
5370        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5371    )?;
5372    writeln!(
5373        code,
5374        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5375    )?;
5376    writeln!(
5377        code,
5378        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5379    )?;
5380    writeln!(
5381        code,
5382        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5383    )?;
5384    writeln!(
5385        code,
5386        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5387    )?;
5388    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5389    writeln!(
5390        code,
5391        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5392    )?;
5393    writeln!(
5394        code,
5395        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5396    )?;
5397    writeln!(code, "        }} else {{")?;
5398    writeln!(
5399        code,
5400        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5401    )?;
5402    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5403    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5404    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5405    writeln!(
5406        code,
5407        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5408    )?;
5409    writeln!(
5410        code,
5411        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5412    )?;
5413    writeln!(
5414        code,
5415        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5416    )?;
5417    writeln!(
5418        code,
5419        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5420    )?;
5421    writeln!(
5422        code,
5423        "            let num_tg = (row_tgs * num_tokens) as u64;"
5424    )?;
5425    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5426    writeln!(
5427        code,
5428        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5429    )?;
5430    writeln!(
5431        code,
5432        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5433    )?;
5434    writeln!(code, "        }}")?;
5435    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5436    writeln!(code, "    }}")?;
5437    writeln!(code)?;
5438
5439    // dispatch_matmul_q4_batch
5440    writeln!(
5441        code,
5442        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5443    )?;
5444    writeln!(
5445        code,
5446        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5447    )?;
5448    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5449    writeln!(code, "        let r: u32 = rows as u32;")?;
5450    writeln!(code, "        let c: u32 = cols as u32;")?;
5451    writeln!(
5452        code,
5453        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5454    )?;
5455    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5456    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5457    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5458    writeln!(
5459        code,
5460        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5461    )?;
5462    writeln!(
5463        code,
5464        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5465    )?;
5466    writeln!(
5467        code,
5468        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5469    )?;
5470    writeln!(
5471        code,
5472        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5473    )?;
5474    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5475    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5476    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5477    writeln!(
5478        code,
5479        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5480    )?;
5481    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5482    writeln!(code, "    }}")?;
5483    writeln!(code)?;
5484
5485    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5486    if config.qkv_bias {
5487        writeln!(
5488            code,
5489            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5490        )?;
5491        writeln!(
5492            code,
5493            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5494        )?;
5495        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5496        writeln!(code, "        let r: u32 = rows as u32;")?;
5497        writeln!(
5498            code,
5499            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5500        )?;
5501        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5502        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5503        writeln!(
5504            code,
5505            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5506        )?;
5507        writeln!(
5508            code,
5509            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5510        )?;
5511        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5512        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5513        writeln!(
5514            code,
5515            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5516        )?;
5517        writeln!(
5518            code,
5519            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5520        )?;
5521        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5522        writeln!(code, "    }}")?;
5523        writeln!(code)?;
5524    }
5525
5526    // dispatch_rope_batch
5527    writeln!(
5528        code,
5529        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5530    )?;
5531    writeln!(
5532        code,
5533        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5534    )?;
5535    writeln!(
5536        code,
5537        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5538    )?;
5539    writeln!(
5540        code,
5541        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5542    )?;
5543    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5544    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5545    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5546    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5547    writeln!(
5548        code,
5549        "        let pairs_per_token = num_heads * (head_dim / 2);"
5550    )?;
5551    writeln!(
5552        code,
5553        "        let total_pairs = num_tokens * pairs_per_token;"
5554    )?;
5555    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data.
5556    // Since our batch_qkv_buf is [M, qkv_rows] and Q/K are at offsets within each row,
5557    // we need to pass the buffer at the right byte offset for each token's data.
5558    // Actually, the rope_batch kernel accesses data[token * (num_heads * head_dim) + ...],
5559    // but our layout is data[token * row_stride + data_float_offset + ...].
5560    // We need the kernel to know the row_stride. Let me adjust the kernel approach:
5561    // Since Q and K are contiguous within each token's qkv_rows, and the batch buffer
5562    // is [M, qkv_rows], we can pass the buffer at offset (data_float_offset * 4) and
5563    // use a stride parameter. But the rope_batch kernel as written expects [M, num_heads*head_dim].
5564    //
5565    // Simplest approach: use the single-token rope kernel for each token in a loop.
5566    // This is still efficient because we're dispatching all within the same command encoder.
5567    writeln!(
5568        code,
5569        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5570    )?;
5571    writeln!(code, "        for t in 0..num_tokens {{")?;
5572    writeln!(
5573        code,
5574        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5575    )?;
5576    writeln!(
5577        code,
5578        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5579    )?;
5580    writeln!(
5581        code,
5582        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5583    )?;
5584    writeln!(
5585        code,
5586        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5587    )?;
5588    writeln!(
5589        code,
5590        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5591    )?;
5592    writeln!(
5593        code,
5594        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5595    )?;
5596    writeln!(
5597        code,
5598        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5599    )?;
5600    writeln!(
5601        code,
5602        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5603    )?;
5604    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5605    writeln!(
5606        code,
5607        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5608    )?;
5609    writeln!(
5610        code,
5611        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5612    )?;
5613    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5614    writeln!(code, "        }}")?;
5615    writeln!(code, "    }}")?;
5616    writeln!(code)?;
5617
5618    // dispatch_silu_mul_fused_batch
5619    writeln!(
5620        code,
5621        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5622    )?;
5623    writeln!(
5624        code,
5625        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5626    )?;
5627    writeln!(code, "        let count: u32 = n as u32;")?;
5628    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5629    writeln!(
5630        code,
5631        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5632    )?;
5633    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5634    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5635    writeln!(
5636        code,
5637        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5638    )?;
5639    writeln!(
5640        code,
5641        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5642    )?;
5643    writeln!(code, "        let total = num_tokens * n;")?;
5644    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5645    writeln!(
5646        code,
5647        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5648    )?;
5649    writeln!(
5650        code,
5651        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5652    )?;
5653    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5654    writeln!(code, "    }}")?;
5655    writeln!(code)?;
5656
5657    // dispatch_add_inplace_batch_n (add n elements in-place)
5658    writeln!(
5659        code,
5660        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5661    )?;
5662    writeln!(
5663        code,
5664        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5665    )?;
5666    writeln!(code, "        let count: u32 = total_n as u32;")?;
5667    writeln!(
5668        code,
5669        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5670    )?;
5671    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5672    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5673    writeln!(
5674        code,
5675        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5676    )?;
5677    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5678    writeln!(
5679        code,
5680        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5681    )?;
5682    writeln!(
5683        code,
5684        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5685    )?;
5686    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5687    writeln!(code, "    }}")?;
5688    writeln!(code)?;
5689
5690    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5691    writeln!(
5692        code,
5693        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5694    )?;
5695    writeln!(
5696        code,
5697        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5698    )?;
5699    writeln!(code, "        let n: u32 = count as u32;")?;
5700    writeln!(
5701        code,
5702        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5703    )?;
5704    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5705    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5706    writeln!(
5707        code,
5708        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5709    )?;
5710    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5711    writeln!(
5712        code,
5713        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5714    )?;
5715    writeln!(
5716        code,
5717        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5718    )?;
5719    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5720    writeln!(code, "    }}")?;
5721    writeln!(code)?;
5722
5723    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
5724    writeln!(
5725        code,
5726        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5727    )?;
5728    writeln!(
5729        code,
5730        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5731    )?;
5732    writeln!(code, "        let n: u32 = count as u32;")?;
5733    writeln!(
5734        code,
5735        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5736    )?;
5737    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5738    writeln!(
5739        code,
5740        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5741    )?;
5742    writeln!(
5743        code,
5744        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5745    )?;
5746    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5747    writeln!(
5748        code,
5749        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5750    )?;
5751    writeln!(
5752        code,
5753        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5754    )?;
5755    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5756    writeln!(code, "    }}")?;
5757    writeln!(code)?;
5758
5759    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
5760    writeln!(
5761        code,
5762        "    /// Copy from src at byte offset to dst at float offset."
5763    )?;
5764    writeln!(
5765        code,
5766        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5767    )?;
5768    writeln!(code, "        let n: u32 = count as u32;")?;
5769    writeln!(
5770        code,
5771        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5772    )?;
5773    writeln!(
5774        code,
5775        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5776    )?;
5777    writeln!(
5778        code,
5779        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5780    )?;
5781    writeln!(
5782        code,
5783        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5784    )?;
5785    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5786    writeln!(
5787        code,
5788        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5789    )?;
5790    writeln!(
5791        code,
5792        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5793    )?;
5794    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5795    writeln!(code, "    }}")?;
5796    writeln!(code)?;
5797
5798    // dispatch_copy_kv_batch
5799    writeln!(
5800        code,
5801        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
5802    )?;
5803    writeln!(
5804        code,
5805        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
5806    )?;
5807    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5808    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5809    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5810    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5811    writeln!(code, "        let so: u32 = src_offset as u32;")?;
5812    writeln!(
5813        code,
5814        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
5815    )?;
5816    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5817    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5818    writeln!(
5819        code,
5820        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5821    )?;
5822    writeln!(
5823        code,
5824        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5825    )?;
5826    writeln!(
5827        code,
5828        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5829    )?;
5830    writeln!(
5831        code,
5832        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5833    )?;
5834    writeln!(
5835        code,
5836        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
5837    )?;
5838    writeln!(code, "        let total = num_tokens * kv_dim;")?;
5839    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5840    writeln!(
5841        code,
5842        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5843    )?;
5844    writeln!(
5845        code,
5846        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5847    )?;
5848    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5849    writeln!(code, "    }}")?;
5850    writeln!(code)?;
5851
5852    // dispatch_attention_batch
5853    writeln!(
5854        code,
5855        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
5856    )?;
5857    writeln!(
5858        code,
5859        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
5860    )?;
5861    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5862    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5863    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5864    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5865    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5866    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
5867    // The legacy attention_batch kernel materializes a scores[2048] array
5868    // in threadgroup memory — fast at short seq_len but silently corrupts
5869    // when `base_pos + num_tokens > 2048`.  Switch to the flash kernel
5870    // (tiled online softmax, O(head_dim) threadgroup memory, no seq_len
5871    // cap) above that boundary.  Below it the legacy kernel is ~15% faster
5872    // due to the flash kernel's per-tile barrier overhead at low seq_len.
5873    writeln!(
5874        code,
5875        "        let max_seq = base_pos + num_tokens;"
5876    )?;
5877    // Temporarily: always use legacy — flash kernel has a bug discovered
5878    // after the chunking fix exposed multi-chunk paths at seq_len > 2000.
5879    writeln!(
5880        code,
5881        "        let _ = max_seq;"
5882    )?;
5883    writeln!(
5884        code,
5885        "        let pipe = &self.attention_batch_pipeline;"
5886    )?;
5887    writeln!(
5888        code,
5889        "        enc.set_compute_pipeline_state(pipe);"
5890    )?;
5891    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
5892    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5893    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5894    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5895    writeln!(
5896        code,
5897        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5898    )?;
5899    writeln!(
5900        code,
5901        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5902    )?;
5903    writeln!(
5904        code,
5905        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5906    )?;
5907    writeln!(
5908        code,
5909        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5910    )?;
5911    writeln!(
5912        code,
5913        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5914    )?;
5915    writeln!(
5916        code,
5917        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5918    )?;
5919    writeln!(
5920        code,
5921        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
5922    )?;
5923    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5924    writeln!(
5925        code,
5926        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
5927    )?;
5928    writeln!(
5929        code,
5930        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5931    )?;
5932    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5933    writeln!(code, "    }}")?;
5934    writeln!(code)?;
5935
5936    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
5937    writeln!(
5938        code,
5939        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
5940    )?;
5941    writeln!(
5942        code,
5943        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
5944    )?;
5945    writeln!(
5946        code,
5947        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
5948    )?;
5949    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5950    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5951    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
5952    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5953    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5954    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
5955    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5956    writeln!(
5957        code,
5958        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
5959    )?;
5960    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
5961    writeln!(
5962        code,
5963        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5964    )?;
5965    writeln!(
5966        code,
5967        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5968    )?;
5969    writeln!(
5970        code,
5971        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
5972    )?;
5973    writeln!(
5974        code,
5975        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5976    )?;
5977    writeln!(
5978        code,
5979        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5980    )?;
5981    writeln!(
5982        code,
5983        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5984    )?;
5985    writeln!(
5986        code,
5987        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5988    )?;
5989    writeln!(
5990        code,
5991        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
5992    )?;
5993    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5994    writeln!(
5995        code,
5996        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5997    )?;
5998    writeln!(
5999        code,
6000        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6001    )?;
6002    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6003    writeln!(code, "    }}")?;
6004    writeln!(code)?;
6005
6006    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
6007    writeln!(
6008        code,
6009        "    /// Dispatch fused K+V cache copy in one kernel launch."
6010    )?;
6011    writeln!(
6012        code,
6013        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6014    )?;
6015    writeln!(
6016        code,
6017        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6018    )?;
6019    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6020    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6021    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6022    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6023    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6024    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6025    writeln!(
6026        code,
6027        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6028    )?;
6029    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6030    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6031    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6032    writeln!(
6033        code,
6034        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6035    )?;
6036    writeln!(
6037        code,
6038        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6039    )?;
6040    writeln!(
6041        code,
6042        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6043    )?;
6044    writeln!(
6045        code,
6046        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6047    )?;
6048    writeln!(
6049        code,
6050        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6051    )?;
6052    writeln!(
6053        code,
6054        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6055    )?;
6056    writeln!(
6057        code,
6058        "        let total = num_tokens * kv_dim * 2;  // K + V"
6059    )?;
6060    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6061    writeln!(
6062        code,
6063        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6064    )?;
6065    writeln!(
6066        code,
6067        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6068    )?;
6069    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6070    writeln!(code, "    }}")?;
6071
6072    writeln!(code, "}}")?;
6073    writeln!(code)?;
6074
6075    Ok(())
6076}
6077
6078fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6079    writeln!(
6080        code,
6081        "// ── Helper functions ──────────────────────────────────"
6082    )?;
6083    writeln!(code)?;
6084    writeln!(
6085        code,
6086        "/// Create a compute pipeline from a named function in the Metal library."
6087    )?;
6088    writeln!(
6089        code,
6090        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6091    )?;
6092    writeln!(
6093        code,
6094        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6095    )?;
6096    writeln!(
6097        code,
6098        "    device.new_compute_pipeline_state_with_function(&func)"
6099    )?;
6100    writeln!(
6101        code,
6102        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6103    )?;
6104    writeln!(code, "}}")?;
6105    writeln!(code)?;
6106
6107    Ok(())
6108}
6109
6110// ---------------------------------------------------------------------------
6111// main.rs generation
6112// ---------------------------------------------------------------------------
6113
6114fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6115    let _sanitized = sanitize_name(model_name);
6116    let _vocab = config.vocab_size;
6117
6118    let mut code = String::with_capacity(16 * 1024);
6119    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6120    writeln!(
6121        code,
6122        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6123    )?;
6124    writeln!(code)?;
6125    writeln!(code, "mod model;")?;
6126    writeln!(code)?;
6127    writeln!(code, "use std::io::Write;")?;
6128    writeln!(code, "use std::time::Instant;")?;
6129    writeln!(code, "use serde::Deserialize;")?;
6130    writeln!(code)?;
6131
6132    // -- main function --
6133    writeln!(code, "fn main() {{")?;
6134    writeln!(
6135        code,
6136        "    let args: Vec<String> = std::env::args().collect();"
6137    )?;
6138    writeln!(code)?;
6139    writeln!(
6140        code,
6141        "    // Detect --serve mode (only requires weights + tokenizer)"
6142    )?;
6143    writeln!(
6144        code,
6145        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6146    )?;
6147    writeln!(code)?;
6148    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6149    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6150    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6151    writeln!(code, "        std::process::exit(1);")?;
6152    writeln!(code, "    }}")?;
6153    writeln!(code)?;
6154    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6155    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6156    writeln!(code, "        std::process::exit(1);")?;
6157    writeln!(code, "    }}")?;
6158    writeln!(code)?;
6159    writeln!(code, "    let weights_path = &args[1];")?;
6160    writeln!(code, "    let tokenizer_path = &args[2];")?;
6161    writeln!(code)?;
6162    writeln!(code, "    // Parse optional flags")?;
6163    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6164    writeln!(code, "    let mut port: u16 = 8080;")?;
6165    writeln!(
6166        code,
6167        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6168    )?;
6169    writeln!(
6170        code,
6171        "    let profile = args.iter().any(|a| a == \"--profile\");"
6172    )?;
6173    writeln!(code, "    let mut i = 3;")?;
6174    writeln!(code, "    while i < args.len() {{")?;
6175    writeln!(
6176        code,
6177        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6178    )?;
6179    writeln!(
6180        code,
6181        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6182    )?;
6183    writeln!(code, "            i += 2;")?;
6184    writeln!(
6185        code,
6186        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6187    )?;
6188    writeln!(
6189        code,
6190        "            port = args[i + 1].parse().unwrap_or(8080);"
6191    )?;
6192    writeln!(code, "            i += 2;")?;
6193    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6194    writeln!(code, "            i += 1;")?;
6195    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6196    writeln!(code, "            i += 1;")?;
6197    writeln!(code, "        }} else {{")?;
6198    writeln!(code, "            i += 1;")?;
6199    writeln!(code, "        }}")?;
6200    writeln!(code, "    }}")?;
6201    writeln!(code)?;
6202
6203    // -- load model (shared by both modes) --
6204    writeln!(
6205        code,
6206        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6207    )?;
6208    writeln!(
6209        code,
6210        "    let weights_file = std::fs::File::open(weights_path)"
6211    )?;
6212    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6213    writeln!(
6214        code,
6215        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6216    )?;
6217    writeln!(code)?;
6218    writeln!(code, "    // Load tokenizer")?;
6219    writeln!(
6220        code,
6221        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6222    )?;
6223    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6224    writeln!(code)?;
6225    writeln!(code, "    // Create Metal model")?;
6226    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6227    writeln!(
6228        code,
6229        "    let mut model = model::MetalModel::new(&weights_mmap);"
6230    )?;
6231    writeln!(code)?;
6232
6233    // -- branch: serve vs CLI --
6234    writeln!(code, "    if serve_mode {{")?;
6235    writeln!(code, "        serve(model, tokenizer, port);")?;
6236    writeln!(code, "    }} else {{")?;
6237    writeln!(code, "        let prompt = &args[3];")?;
6238    writeln!(
6239        code,
6240        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6241    )?;
6242    writeln!(code, "    }}")?;
6243    writeln!(code, "}}")?;
6244    writeln!(code)?;
6245
6246    // -- cli_mode function --
6247    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6248    writeln!(code, "    // Tokenize prompt")?;
6249    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6250    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6251    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6252    writeln!(code)?;
6253    writeln!(
6254        code,
6255        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6256    )?;
6257    writeln!(
6258        code,
6259        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6260    )?;
6261    writeln!(
6262        code,
6263        "    // The last token uses synchronous forward() to get logits."
6264    )?;
6265    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6266    writeln!(code, "    let prefill_start = Instant::now();")?;
6267    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6268    writeln!(
6269        code,
6270        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6271    )?;
6272    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6273    writeln!(code, "    }} else {{")?;
6274    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6275    writeln!(code, "    }};")?;
6276    writeln!(
6277        code,
6278        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6279    )?;
6280    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6281    writeln!(
6282        code,
6283        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6284    )?;
6285    writeln!(
6286        code,
6287        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6288    )?;
6289    writeln!(code)?;
6290    writeln!(code, "    // Generate tokens")?;
6291    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6292    writeln!(code, "    let gen_start = Instant::now();")?;
6293    writeln!(code, "    let mut generated_count: usize = 0;")?;
6294    writeln!(code)?;
6295    writeln!(code, "    for _ in 0..max_tokens {{")?;
6296    writeln!(
6297        code,
6298        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6299    )?;
6300    writeln!(code, "            if !quiet {{")?;
6301    writeln!(code, "                print!(\"{{}}\", text);")?;
6302    writeln!(code, "                std::io::stdout().flush().ok();")?;
6303    writeln!(code, "            }}")?;
6304    writeln!(code, "        }}")?;
6305    writeln!(code, "        generated_count += 1;")?;
6306    writeln!(code)?;
6307    writeln!(
6308        code,
6309        "        // Use profiling forward for first token when --profile is set"
6310    )?;
6311    writeln!(
6312        code,
6313        "        let logits = if profile && generated_count == 1 {{"
6314    )?;
6315    writeln!(code, "            model.forward_profile(next_token)")?;
6316    writeln!(code, "        }} else {{")?;
6317    writeln!(code, "            model.forward(next_token)")?;
6318    writeln!(code, "        }};")?;
6319    writeln!(code, "        next_token = argmax(&logits);")?;
6320    writeln!(code)?;
6321    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6322    writeln!(code, "        if next_token == 2 {{")?;
6323    writeln!(code, "            break;")?;
6324    writeln!(code, "        }}")?;
6325    writeln!(code)?;
6326    writeln!(
6327        code,
6328        "        // Yield between tokens to reduce sustained GPU thermal load."
6329    )?;
6330    writeln!(
6331        code,
6332        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6333    )?;
6334    writeln!(
6335        code,
6336        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6337    )?;
6338    writeln!(
6339        code,
6340        "        // briefly, providing a micro-break that helps sustain peak throughput."
6341    )?;
6342    writeln!(code, "        std::thread::yield_now();")?;
6343    writeln!(code, "    }}")?;
6344    writeln!(code, "    if !quiet {{")?;
6345    writeln!(code, "        println!();")?;
6346    writeln!(code, "    }}")?;
6347    writeln!(
6348        code,
6349        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6350    )?;
6351    writeln!(
6352        code,
6353        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6354    )?;
6355    writeln!(
6356        code,
6357        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6358    )?;
6359    writeln!(code, "}}")?;
6360    writeln!(code)?;
6361
6362    // -- argmax helper --
6363    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6364    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6365    writeln!(code, "    logits.iter()")?;
6366    writeln!(code, "        .enumerate()")?;
6367    writeln!(
6368        code,
6369        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6370    )?;
6371    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6372    writeln!(code, "        .unwrap_or(0)")?;
6373    writeln!(code, "}}")?;
6374    writeln!(code)?;
6375
6376    // -- Request/Response types for OpenAI API --
6377    writeln!(
6378        code,
6379        "// -----------------------------------------------------------------------"
6380    )?;
6381    writeln!(code, "// OpenAI-compatible API server")?;
6382    writeln!(
6383        code,
6384        "// -----------------------------------------------------------------------"
6385    )?;
6386    writeln!(code)?;
6387    writeln!(code, "#[derive(Deserialize)]")?;
6388    writeln!(code, "struct ChatRequest {{")?;
6389    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6390    writeln!(code, "    #[serde(default)]")?;
6391    writeln!(code, "    stream: Option<bool>,")?;
6392    writeln!(code, "    #[serde(default)]")?;
6393    writeln!(code, "    max_tokens: Option<usize>,")?;
6394    writeln!(code, "    #[serde(default)]")?;
6395    writeln!(code, "    temperature: Option<f32>,")?;
6396    writeln!(code, "    #[serde(default)]")?;
6397    writeln!(code, "    model: Option<String>,")?;
6398    writeln!(code, "}}")?;
6399    writeln!(code)?;
6400    writeln!(code, "#[derive(Deserialize)]")?;
6401    writeln!(code, "struct ChatMessage {{")?;
6402    writeln!(code, "    role: String,")?;
6403    writeln!(code, "    content: String,")?;
6404    writeln!(code, "}}")?;
6405    writeln!(code)?;
6406
6407    // -- format_chat_messages --
6408    writeln!(
6409        code,
6410        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6411    )?;
6412    writeln!(code, "    let mut prompt = String::new();")?;
6413    writeln!(code, "    for msg in messages {{")?;
6414    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6415    writeln!(code, "    }}")?;
6416    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6417    writeln!(code, "    prompt")?;
6418    writeln!(code, "}}")?;
6419    writeln!(code)?;
6420
6421    // -- prefill helper --
6422    writeln!(
6423        code,
6424        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6425    )?;
6426    writeln!(code, "    let len = tokens.len();")?;
6427    writeln!(code, "    if len > 1 {{")?;
6428    writeln!(
6429        code,
6430        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6431    )?;
6432    writeln!(code, "    }}")?;
6433    writeln!(code, "    model.forward(tokens[len - 1])")?;
6434    writeln!(code, "}}")?;
6435    writeln!(code)?;
6436
6437    // -- serve function --
6438    writeln!(
6439        code,
6440        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6441    )?;
6442    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6443    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6444    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6445    writeln!(
6446        code,
6447        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6448    )?;
6449    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6450    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6451    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6452    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6453    writeln!(code)?;
6454    writeln!(code, "    for request in server.incoming_requests() {{")?;
6455    writeln!(code, "        let method = request.method().to_string();")?;
6456    writeln!(code, "        let url = request.url().to_string();")?;
6457    writeln!(code)?;
6458    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6459
6460    // -- POST /v1/chat/completions --
6461    writeln!(
6462        code,
6463        "            (\"POST\", \"/v1/chat/completions\") => {{"
6464    )?;
6465    writeln!(
6466        code,
6467        "                handle_chat_completion(&mut model, &tokenizer, request);"
6468    )?;
6469    writeln!(code, "            }}")?;
6470
6471    // -- GET /v1/models --
6472    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6473    writeln!(code, "                let body = serde_json::json!({{")?;
6474    writeln!(code, "                    \"object\": \"list\",")?;
6475    writeln!(code, "                    \"data\": [{{")?;
6476    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6477    writeln!(code, "                        \"object\": \"model\",")?;
6478    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6479    writeln!(code, "                    }}]")?;
6480    writeln!(code, "                }});")?;
6481    writeln!(
6482        code,
6483        "                let resp = tiny_http::Response::from_string(body.to_string())"
6484    )?;
6485    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6486    writeln!(code, "                request.respond(resp).ok();")?;
6487    writeln!(code, "            }}")?;
6488
6489    // -- GET /health --
6490    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6491    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6492    writeln!(code, "                request.respond(resp).ok();")?;
6493    writeln!(code, "            }}")?;
6494
6495    // -- 404 --
6496    writeln!(code, "            _ => {{")?;
6497    writeln!(
6498        code,
6499        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6500    )?;
6501    writeln!(code, "                    .with_status_code(404);")?;
6502    writeln!(code, "                request.respond(resp).ok();")?;
6503    writeln!(code, "            }}")?;
6504    writeln!(code, "        }}")?;
6505    writeln!(code, "    }}")?;
6506    writeln!(code, "}}")?;
6507    writeln!(code)?;
6508
6509    // -- handle_chat_completion --
6510    writeln!(code, "fn handle_chat_completion(")?;
6511    writeln!(code, "    model: &mut model::MetalModel,")?;
6512    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6513    writeln!(code, "    mut request: tiny_http::Request,")?;
6514    writeln!(code, ") {{")?;
6515    writeln!(code, "    // Read request body")?;
6516    writeln!(code, "    let mut body = String::new();")?;
6517    writeln!(
6518        code,
6519        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6520    )?;
6521    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6522    writeln!(code, "            .with_status_code(400);")?;
6523    writeln!(code, "        request.respond(resp).ok();")?;
6524    writeln!(code, "        return;")?;
6525    writeln!(code, "    }}")?;
6526    writeln!(code)?;
6527    writeln!(code, "    // Parse JSON")?;
6528    writeln!(
6529        code,
6530        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6531    )?;
6532    writeln!(code, "        Ok(r) => r,")?;
6533    writeln!(code, "        Err(e) => {{")?;
6534    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6535    writeln!(code, "                .with_status_code(400);")?;
6536    writeln!(code, "            request.respond(resp).ok();")?;
6537    writeln!(code, "            return;")?;
6538    writeln!(code, "        }}")?;
6539    writeln!(code, "    }};")?;
6540    writeln!(code)?;
6541    writeln!(
6542        code,
6543        "    let prompt = format_chat_messages(&req.messages);"
6544    )?;
6545    writeln!(
6546        code,
6547        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6548    )?;
6549    writeln!(code, "        Ok(e) => e,")?;
6550    writeln!(code, "        Err(e) => {{")?;
6551    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6552    writeln!(code, "                .with_status_code(500);")?;
6553    writeln!(code, "            request.respond(resp).ok();")?;
6554    writeln!(code, "            return;")?;
6555    writeln!(code, "        }}")?;
6556    writeln!(code, "    }};")?;
6557    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6558    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6559    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6560    writeln!(
6561        code,
6562        "    let _temperature = req.temperature.unwrap_or(1.0);"
6563    )?;
6564    writeln!(code)?;
6565
6566    // -- Reset KV cache for each request --
6567    writeln!(code, "    model.reset();")?;
6568    writeln!(code)?;
6569
6570    // -- Prefill with timing --
6571    writeln!(code, "    let prefill_start = Instant::now();")?;
6572    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6573    writeln!(
6574        code,
6575        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6576    )?;
6577    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6578    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6579    writeln!(code)?;
6580
6581    writeln!(code, "    if stream {{")?;
6582
6583    // -- SSE streaming response --
6584    writeln!(
6585        code,
6586        "        // SSE streaming: generate tokens and build SSE body"
6587    )?;
6588    writeln!(code, "        let gen_start = Instant::now();")?;
6589    writeln!(code, "        let mut generated_count: usize = 0;")?;
6590    writeln!(code, "        let mut sse_body = String::new();")?;
6591    writeln!(code, "        for _ in 0..max_tokens {{")?;
6592    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6593    writeln!(
6594        code,
6595        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6596    )?;
6597    writeln!(
6598        code,
6599        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6600    )?;
6601    writeln!(
6602        code,
6603        "                // escaped includes surrounding quotes, strip them"
6604    )?;
6605    writeln!(
6606        code,
6607        "                let inner = &escaped[1..escaped.len()-1];"
6608    )?;
6609    writeln!(code, "                sse_body.push_str(&format!(")?;
6610    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6611    writeln!(code, "                    inner")?;
6612    writeln!(code, "                ));")?;
6613    writeln!(code, "            }}")?;
6614    writeln!(code, "            generated_count += 1;")?;
6615    writeln!(code, "            let logits = model.forward(next_token);")?;
6616    writeln!(code, "            next_token = argmax(&logits);")?;
6617    writeln!(code, "        }}")?;
6618    writeln!(
6619        code,
6620        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6621    )?;
6622    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6623    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6624    writeln!(code)?;
6625    writeln!(
6626        code,
6627        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6628    )?;
6629    writeln!(code, "        sse_body.push_str(&format!(")?;
6630    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6631    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6632    writeln!(code, "        ));")?;
6633    writeln!(code)?;
6634    writeln!(
6635        code,
6636        "        let resp = tiny_http::Response::from_string(sse_body)"
6637    )?;
6638    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6639    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6640    writeln!(code, "        request.respond(resp).ok();")?;
6641
6642    writeln!(code, "    }} else {{")?;
6643
6644    // -- Non-streaming response --
6645    writeln!(
6646        code,
6647        "        // Non-streaming: generate all tokens, return JSON"
6648    )?;
6649    writeln!(code, "        let gen_start = Instant::now();")?;
6650    writeln!(code, "        let mut generated_count: usize = 0;")?;
6651    writeln!(code, "        let mut generated = String::new();")?;
6652    writeln!(code, "        for _ in 0..max_tokens {{")?;
6653    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6654    writeln!(
6655        code,
6656        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6657    )?;
6658    writeln!(code, "                generated.push_str(&text);")?;
6659    writeln!(code, "            }}")?;
6660    writeln!(code, "            generated_count += 1;")?;
6661    writeln!(code, "            let logits = model.forward(next_token);")?;
6662    writeln!(code, "            next_token = argmax(&logits);")?;
6663    writeln!(code, "        }}")?;
6664    writeln!(
6665        code,
6666        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6667    )?;
6668    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6669    writeln!(code)?;
6670    writeln!(code, "        let resp_json = serde_json::json!({{")?;
6671    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
6672    writeln!(code, "            \"object\": \"chat.completion\",")?;
6673    writeln!(code, "            \"choices\": [{{")?;
6674    writeln!(code, "                \"index\": 0,")?;
6675    writeln!(code, "                \"message\": {{")?;
6676    writeln!(code, "                    \"role\": \"assistant\",")?;
6677    writeln!(code, "                    \"content\": generated")?;
6678    writeln!(code, "                }},")?;
6679    writeln!(code, "                \"finish_reason\": \"stop\"")?;
6680    writeln!(code, "            }}],")?;
6681    writeln!(code, "            \"usage\": {{")?;
6682    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
6683    writeln!(
6684        code,
6685        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6686    )?;
6687    writeln!(
6688        code,
6689        "                \"generation_tokens\": generated_count,"
6690    )?;
6691    writeln!(
6692        code,
6693        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6694    )?;
6695    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
6696    writeln!(code, "            }}")?;
6697    writeln!(code, "        }});")?;
6698    writeln!(
6699        code,
6700        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
6701    )?;
6702    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6703    writeln!(code, "        request.respond(resp).ok();")?;
6704    writeln!(code, "    }}")?;
6705    writeln!(code, "}}")?;
6706
6707    Ok(code)
6708}
6709
6710// ---------------------------------------------------------------------------
6711// Tests
6712// ---------------------------------------------------------------------------
6713
6714#[cfg(test)]
6715mod tests {
6716    use super::*;
6717    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
6718
6719    fn minimal_config() -> ModelConfig {
6720        ModelConfig {
6721            architecture: Architecture::Llama,
6722            hidden_size: 64,
6723            intermediate_size: 128,
6724            num_layers: 2,
6725            num_attention_heads: 4,
6726            num_kv_heads: 4,
6727            head_dim: 16,
6728            vocab_size: 256,
6729            max_seq_len: 512,
6730            rms_norm_eps: 1e-5,
6731            rope_theta: 10000.0,
6732            dtype: DType::F32,
6733            sliding_window_size: None,
6734            qkv_bias: false,
6735        }
6736    }
6737
6738    fn minimal_graph() -> Graph {
6739        Graph::new("test-metal").with_config(minimal_config())
6740    }
6741
6742    #[test]
6743    fn generate_metal_project_creates_files() {
6744        let dir = tempfile::tempdir().unwrap();
6745        let graph = minimal_graph();
6746        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
6747
6748        assert!(
6749            dir.path().join("Cargo.toml").exists(),
6750            "Cargo.toml should be created"
6751        );
6752        assert!(
6753            dir.path().join("src/model.rs").exists(),
6754            "src/model.rs should be created"
6755        );
6756        assert!(
6757            dir.path().join("src/main.rs").exists(),
6758            "src/main.rs should be created"
6759        );
6760        assert!(
6761            dir.path().join("shaders/kernels.metal").exists(),
6762            "shaders/kernels.metal should be created"
6763        );
6764    }
6765
6766    #[test]
6767    fn generated_cargo_toml_has_metal_dep() {
6768        let toml = generate_cargo_toml("my-model");
6769        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
6770        assert!(
6771            toml.contains("tokenizers"),
6772            "Cargo.toml should depend on tokenizers"
6773        );
6774        assert!(
6775            toml.contains("memmap2"),
6776            "Cargo.toml should depend on memmap2"
6777        );
6778        assert!(toml.contains("half"), "Cargo.toml should depend on half");
6779    }
6780
6781    #[test]
6782    fn generated_model_rs_contains_metal_code() {
6783        let config = minimal_config();
6784        let model_rs = generate_model_rs(&config).unwrap();
6785
6786        assert!(
6787            model_rs.contains("pub struct MetalModel"),
6788            "model.rs should define MetalModel struct"
6789        );
6790        assert!(
6791            model_rs.contains("matmul_pipeline: ComputePipelineState"),
6792            "MetalModel should have matmul_pipeline field"
6793        );
6794        assert!(
6795            model_rs.contains("Device::system_default()"),
6796            "model.rs should use Metal device"
6797        );
6798        assert!(
6799            model_rs.contains("new_library_with_source"),
6800            "model.rs should compile Metal shaders"
6801        );
6802        assert!(
6803            model_rs.contains("fn new(weights: &[u8])"),
6804            "MetalModel should implement new()"
6805        );
6806        assert!(
6807            model_rs.contains("fn forward(&mut self, token_id: u32)"),
6808            "MetalModel should implement forward()"
6809        );
6810    }
6811
6812    #[test]
6813    fn generated_shaders_contain_kernel_names() {
6814        let shaders = generate_metal_shaders(&minimal_config());
6815
6816        assert!(
6817            shaders.contains("kernel void matmul_vec"),
6818            "shaders should contain matmul_vec kernel"
6819        );
6820        assert!(
6821            shaders.contains("kernel void rms_norm"),
6822            "shaders should contain rms_norm kernel"
6823        );
6824        assert!(
6825            shaders.contains("kernel void rope"),
6826            "shaders should contain rope kernel"
6827        );
6828        assert!(
6829            shaders.contains("kernel void softmax"),
6830            "shaders should contain softmax kernel"
6831        );
6832        assert!(
6833            shaders.contains("kernel void silu_mul("),
6834            "shaders should contain silu_mul kernel"
6835        );
6836        assert!(
6837            shaders.contains("kernel void silu_mul_fused"),
6838            "shaders should contain silu_mul_fused kernel"
6839        );
6840        assert!(
6841            shaders.contains("kernel void elementwise_add"),
6842            "shaders should contain elementwise_add kernel"
6843        );
6844        assert!(
6845            shaders.contains("kernel void attention"),
6846            "shaders should contain attention kernel"
6847        );
6848        assert!(
6849            shaders.contains("kernel void add_inplace"),
6850            "shaders should contain add_inplace kernel"
6851        );
6852        assert!(
6853            shaders.contains("kernel void copy_buffer"),
6854            "shaders should contain copy_buffer kernel"
6855        );
6856        assert!(
6857            shaders.contains("kernel void copy_offset"),
6858            "shaders should contain copy_offset kernel"
6859        );
6860    }
6861
6862    #[test]
6863    fn generated_shaders_use_simdgroup_features() {
6864        let shaders = generate_metal_shaders(&minimal_config());
6865
6866        assert!(
6867            shaders.contains("threadgroup_barrier"),
6868            "shaders should use threadgroup barriers"
6869        );
6870        assert!(
6871            shaders.contains("threadgroup float"),
6872            "shaders should use threadgroup shared memory"
6873        );
6874        assert!(
6875            shaders.contains("thread_index_in_threadgroup"),
6876            "shaders should use threadgroup indexing"
6877        );
6878        assert!(
6879            shaders.contains("simd_sum"),
6880            "shaders should use simd_sum for warp-level reduction"
6881        );
6882        assert!(
6883            shaders.contains("simd_max"),
6884            "attention kernel should use simd_max for cooperative softmax"
6885        );
6886        assert!(
6887            shaders.contains("thread_index_in_simdgroup"),
6888            "shaders should use simdgroup lane indexing"
6889        );
6890        assert!(
6891            shaders.contains("simdgroup_index_in_threadgroup"),
6892            "shaders should use simdgroup indexing within threadgroup"
6893        );
6894        assert!(
6895            shaders.contains("float4"),
6896            "matmul_vec should use float4 vectorized loads"
6897        );
6898    }
6899
6900    #[test]
6901    fn generated_main_rs_has_tokenizer_usage() {
6902        let config = minimal_config();
6903        let main_rs = generate_main_rs("test-model", &config).unwrap();
6904
6905        assert!(
6906            main_rs.contains("tokenizers::Tokenizer"),
6907            "main.rs should use tokenizers crate"
6908        );
6909        assert!(
6910            main_rs.contains("MetalModel::new"),
6911            "main.rs should call MetalModel::new"
6912        );
6913        assert!(
6914            main_rs.contains("model.forward"),
6915            "main.rs should call model.forward"
6916        );
6917        assert!(
6918            main_rs.contains("memmap2"),
6919            "main.rs should use memmap2 for zero-copy weight loading"
6920        );
6921    }
6922
6923    #[test]
6924    fn missing_config_returns_error() {
6925        let dir = tempfile::tempdir().unwrap();
6926        let graph = Graph::new("no-config");
6927        let result = generate_metal_project(&graph, dir.path(), "fail");
6928        assert!(
6929            matches!(result, Err(MetalCodegenError::MissingConfig)),
6930            "should fail with MissingConfig when graph has no config"
6931        );
6932    }
6933
6934    #[test]
6935    fn sanitize_name_works() {
6936        assert_eq!(sanitize_name("My Model!"), "my-model");
6937        assert_eq!(sanitize_name("test_model"), "test-model");
6938        assert_eq!(sanitize_name("simple"), "simple");
6939    }
6940
6941    #[test]
6942    fn generated_forward_uses_single_command_buffer() {
6943        let config = minimal_config();
6944        let model_rs = generate_model_rs(&config).unwrap();
6945
6946        // The forward function should create exactly one command buffer.
6947        // Use the exact signature to avoid matching forward_prefill/forward_profile.
6948        let forward_start = model_rs
6949            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6950            .unwrap();
6951        let forward_body = &model_rs[forward_start..];
6952        // End at the next pub/private method
6953        let forward_end = forward_body
6954            .find("\n    pub fn forward_profile")
6955            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6956            .or_else(|| forward_body.find("\n    fn dispatch_"))
6957            .unwrap_or(forward_body.len());
6958        let forward_code = &forward_body[..forward_end];
6959
6960        // Should have exactly one new_command_buffer call
6961        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
6962        assert_eq!(
6963            cmd_buf_count, 1,
6964            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
6965        );
6966
6967        // Should have exactly one commit call
6968        let commit_count = forward_code.matches("cmd.commit()").count();
6969        assert_eq!(
6970            commit_count, 1,
6971            "forward() should commit exactly once, found {commit_count}"
6972        );
6973
6974        // Should wait: once for cmd + possibly once for prev_cmd drain
6975        let wait_count = forward_code.matches("wait_until_completed()").count();
6976        assert!(
6977            wait_count >= 1 && wait_count <= 2,
6978            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
6979        );
6980    }
6981
6982    #[test]
6983    fn generated_model_has_preallocated_working_buffers() {
6984        let config = minimal_config();
6985        let model_rs = generate_model_rs(&config).unwrap();
6986
6987        for buf_name in &[
6988            "normed_buf",
6989            "qkv_buf",
6990            "attn_out_buf",
6991            "attn_proj_buf",
6992            "gate_up_buf",
6993            "ffn_hidden_buf",
6994            "ffn_out_buf",
6995            "add_tmp_buf",
6996        ] {
6997            assert!(
6998                model_rs.contains(&format!("{buf_name}: Buffer")),
6999                "MetalModel should have pre-allocated {buf_name} field"
7000            );
7001        }
7002    }
7003
7004    #[test]
7005    fn generated_dispatch_helpers_take_compute_encoder_ref() {
7006        let config = minimal_config();
7007        let model_rs = generate_model_rs(&config).unwrap();
7008
7009        for method in &[
7010            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7011            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7012            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7013            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7014            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7015            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7016            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7017            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7018            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7019            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7020            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7021            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7022            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7023        ] {
7024            assert!(
7025                model_rs.contains(method),
7026                "model.rs should contain dispatch helper: {method}"
7027            );
7028        }
7029    }
7030
7031    #[test]
7032    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7033        let config = minimal_config();
7034        let model_rs = generate_model_rs(&config).unwrap();
7035
7036        // Find dispatch helpers section and check none create their own encoders
7037        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7038        let helpers_code = &model_rs[helpers_start..];
7039
7040        // None of the dispatch_ helpers should call new_command_buffer
7041        assert!(
7042            !helpers_code.contains("self.queue.new_command_buffer()"),
7043            "dispatch helpers should not create their own command buffers"
7044        );
7045
7046        // None should create their own compute encoders
7047        assert!(
7048            !helpers_code.contains("new_compute_command_encoder()"),
7049            "dispatch helpers should not create their own compute encoders"
7050        );
7051
7052        // None should call end_encoding
7053        assert!(
7054            !helpers_code.contains("end_encoding()"),
7055            "dispatch helpers should not call end_encoding"
7056        );
7057
7058        // None should call commit or wait
7059        assert!(
7060            !helpers_code.contains(".commit()"),
7061            "dispatch helpers should not commit command buffers"
7062        );
7063        assert!(
7064            !helpers_code.contains("wait_until_completed"),
7065            "dispatch helpers should not wait on command buffers"
7066        );
7067    }
7068
7069    #[test]
7070    fn generated_forward_batches_compute_encoders() {
7071        let config = minimal_config();
7072        let model_rs = generate_model_rs(&config).unwrap();
7073
7074        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7075        let forward_start = model_rs
7076            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7077            .unwrap();
7078        let forward_body = &model_rs[forward_start..];
7079        let forward_end = forward_body
7080            .find("\n    pub fn forward_profile")
7081            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7082            .or_else(|| forward_body.find("\n    fn dispatch_"))
7083            .unwrap_or(forward_body.len());
7084        let forward_code = &forward_body[..forward_end];
7085
7086        // Forward should not allocate new buffers
7087        assert!(
7088            !forward_code.contains("device.new_buffer"),
7089            "forward() should not allocate new buffers per call"
7090        );
7091
7092        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7093        // Copy operations use compute copy kernels instead of blit encoders.
7094        let compute_encoder_count = forward_code
7095            .matches("new_compute_command_encoder()")
7096            .count();
7097        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7098
7099        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7100        assert_eq!(
7101            compute_encoder_count, 1,
7102            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7103        );
7104        assert_eq!(
7105            blit_encoder_count, 0,
7106            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7107        );
7108    }
7109
7110    #[test]
7111    fn generated_forward_uses_add_inplace() {
7112        let config = minimal_config();
7113        let model_rs = generate_model_rs(&config).unwrap();
7114
7115        // Should use in-place add (no blit copy-back needed)
7116        assert!(
7117            model_rs.contains("dispatch_add_inplace"),
7118            "forward() should use dispatch_add_inplace for residual connections"
7119        );
7120        assert!(
7121            model_rs.contains("add_inplace_pipeline"),
7122            "MetalModel should have add_inplace_pipeline"
7123        );
7124    }
7125
7126    fn minimal_q8_config() -> ModelConfig {
7127        ModelConfig {
7128            architecture: Architecture::Llama,
7129            hidden_size: 64,
7130            intermediate_size: 128,
7131            num_layers: 2,
7132            num_attention_heads: 4,
7133            num_kv_heads: 4,
7134            head_dim: 16,
7135            vocab_size: 256,
7136            max_seq_len: 512,
7137            rms_norm_eps: 1e-5,
7138            rope_theta: 10000.0,
7139            dtype: DType::Q8_0,
7140            sliding_window_size: None,
7141            qkv_bias: false,
7142        }
7143    }
7144
7145    #[test]
7146    fn generated_shaders_contain_q8_kernel() {
7147        let shaders = generate_metal_shaders(&minimal_config());
7148
7149        assert!(
7150            shaders.contains("kernel void matmul_vec_q8"),
7151            "shaders should contain matmul_vec_q8 kernel"
7152        );
7153        assert!(
7154            shaders.contains("device const uchar* matrix"),
7155            "matmul_vec_q8 should accept raw Q8_0 bytes"
7156        );
7157        assert!(
7158            shaders.contains("packed_short4"),
7159            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7160        );
7161        assert!(
7162            shaders.contains("as_type<char2>"),
7163            "matmul_vec_q8 should bitcast short lanes to char2"
7164        );
7165        assert!(
7166            shaders.contains("device const half*"),
7167            "matmul_vec_q8 should read f16 scale via half pointer"
7168        );
7169    }
7170
7171    #[test]
7172    fn generated_model_uses_fused_qkv_projections() {
7173        let config = minimal_config();
7174        let model_rs = generate_model_rs(&config).unwrap();
7175
7176        // Should have fused QKV weight in layer buffers
7177        assert!(
7178            model_rs.contains("qkv_weight: Buffer"),
7179            "LayerBuffers should have fused qkv_weight field"
7180        );
7181        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7182        assert!(
7183            !model_rs.contains("    q_weight: Buffer"),
7184            "LayerBuffers should not have separate q_weight field"
7185        );
7186        assert!(
7187            !model_rs.contains("    k_weight: Buffer"),
7188            "LayerBuffers should not have separate k_weight field"
7189        );
7190        assert!(
7191            !model_rs.contains("    v_weight: Buffer"),
7192            "LayerBuffers should not have separate v_weight field"
7193        );
7194
7195        // Should have fused gate_up_weight
7196        assert!(
7197            model_rs.contains("gate_up_weight: Buffer"),
7198            "LayerBuffers should have fused gate_up_weight field"
7199        );
7200        // Should NOT have separate gate/up weight fields
7201        assert!(
7202            !model_rs.contains("    gate_weight: Buffer"),
7203            "LayerBuffers should not have separate gate_weight field"
7204        );
7205        assert!(
7206            !model_rs.contains("    up_weight: Buffer"),
7207            "LayerBuffers should not have separate up_weight field"
7208        );
7209
7210        // Should have fused working buffers
7211        assert!(
7212            model_rs.contains("qkv_buf: Buffer"),
7213            "MetalModel should have fused qkv_buf"
7214        );
7215        assert!(
7216            model_rs.contains("gate_up_buf: Buffer"),
7217            "MetalModel should have fused gate_up_buf"
7218        );
7219
7220        // Forward pass should use fused dispatch
7221        assert!(
7222            model_rs.contains("dispatch_silu_mul_fused"),
7223            "forward pass should use dispatch_silu_mul_fused"
7224        );
7225        assert!(
7226            model_rs.contains("dispatch_rope_offset"),
7227            "forward pass should use dispatch_rope_offset for fused QKV"
7228        );
7229        assert!(
7230            model_rs.contains("dispatch_attention_offset"),
7231            "forward pass should use dispatch_attention_offset for fused QKV"
7232        );
7233    }
7234
7235    #[test]
7236    fn q8_model_has_matmul_q8_pipeline() {
7237        let config = minimal_q8_config();
7238        let model_rs = generate_model_rs(&config).unwrap();
7239
7240        assert!(
7241            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7242            "MetalModel should have matmul_q8_pipeline field"
7243        );
7244        assert!(
7245            model_rs.contains("matmul_q8_pipeline,"),
7246            "MetalModel Self should include matmul_q8_pipeline"
7247        );
7248    }
7249
7250    #[test]
7251    fn q8_model_uses_dispatch_matmul_q8() {
7252        let config = minimal_q8_config();
7253        let model_rs = generate_model_rs(&config).unwrap();
7254
7255        assert!(
7256            model_rs.contains("dispatch_matmul_q8"),
7257            "Q8_0 model should use dispatch_matmul_q8 for projections"
7258        );
7259        assert!(
7260            model_rs.contains("fn dispatch_matmul_q8"),
7261            "model.rs should define dispatch_matmul_q8 method"
7262        );
7263    }
7264
7265    #[test]
7266    fn q8_model_loads_raw_bytes_not_dequantized() {
7267        let config = minimal_q8_config();
7268        let model_rs = generate_model_rs(&config).unwrap();
7269
7270        // Should NOT contain dequantization code
7271        assert!(
7272            !model_rs.contains("f16_to_f32"),
7273            "Q8_0 model should not dequantize weights to f32"
7274        );
7275        assert!(
7276            !model_rs.contains("f32_data"),
7277            "Q8_0 model should not create f32 weight data"
7278        );
7279
7280        // Should load raw Q8_0 bytes directly
7281        assert!(
7282            model_rs.contains("total_raw as u64"),
7283            "Q8_0 model should load raw bytes into Metal buffer"
7284        );
7285    }
7286
7287    #[test]
7288    fn q8_model_norms_stay_f32() {
7289        let config = minimal_q8_config();
7290        let model_rs = generate_model_rs(&config).unwrap();
7291
7292        // Norm weights should still use f32 buffers
7293        assert!(
7294            model_rs.contains("let attn_norm = next_f32_buffer"),
7295            "attn_norm should use f32 buffer even for Q8_0 models"
7296        );
7297        assert!(
7298            model_rs.contains("let ffn_norm = next_f32_buffer"),
7299            "ffn_norm should use f32 buffer even for Q8_0 models"
7300        );
7301        assert!(
7302            model_rs.contains("let norm_buf = next_f32_buffer"),
7303            "final norm should use f32 buffer even for Q8_0 models"
7304        );
7305    }
7306
7307    #[test]
7308    fn q8_model_uses_fused_weight_loading() {
7309        let config = minimal_q8_config();
7310        let model_rs = generate_model_rs(&config).unwrap();
7311
7312        // Should use fused Q8 buffer loading for QKV
7313        assert!(
7314            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7315            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7316        );
7317        // Should use fused Q8 buffer loading for gate+up
7318        assert!(
7319            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7320            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7321        );
7322        // Should still use regular q8 buffer for individual weights
7323        assert!(
7324            model_rs.contains("let o_weight = next_q8_buffer"),
7325            "Q8_0 model should use next_q8_buffer for O weight"
7326        );
7327        assert!(
7328            model_rs.contains("let down_weight = next_q8_buffer"),
7329            "Q8_0 model should use next_q8_buffer for down weight"
7330        );
7331    }
7332
7333    #[test]
7334    fn f32_model_does_not_use_q8_dispatch() {
7335        let config = minimal_config();
7336        let model_rs = generate_model_rs(&config).unwrap();
7337
7338        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7339        let forward_start = model_rs
7340            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7341            .unwrap();
7342        let forward_body = &model_rs[forward_start..];
7343        let forward_end = forward_body
7344            .find("\n    fn dispatch_")
7345            .unwrap_or(forward_body.len());
7346        let forward_code = &forward_body[..forward_end];
7347
7348        assert!(
7349            !forward_code.contains("dispatch_matmul_q8"),
7350            "f32 model forward should not use dispatch_matmul_q8"
7351        );
7352    }
7353
7354    #[test]
7355    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7356        let config = minimal_q8_config();
7357        let model_rs = generate_model_rs(&config).unwrap();
7358
7359        assert!(
7360            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7361            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7362        );
7363    }
7364
7365    #[test]
7366    fn generated_model_has_double_buffered_prefill() {
7367        let config = minimal_config();
7368        let model_rs = generate_model_rs(&config).unwrap();
7369
7370        // MetalModel should have prev_cmd field for double-buffered prefill
7371        assert!(
7372            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7373            "MetalModel should have prev_cmd field for double-buffered prefill"
7374        );
7375
7376        // Should have forward_prefill method
7377        assert!(
7378            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7379            "MetalModel should have forward_prefill method"
7380        );
7381
7382        // forward() should drain prev_cmd at the start
7383        assert!(
7384            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7385            "forward() should drain prev_cmd from previous prefill"
7386        );
7387    }
7388
7389    #[test]
7390    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7391        let config = minimal_config();
7392        let main_rs = generate_main_rs("test-model", &config).unwrap();
7393
7394        assert!(
7395            main_rs.contains("forward_prefill"),
7396            "main.rs should use forward_prefill for intermediate prompt tokens"
7397        );
7398        assert!(
7399            main_rs.contains("double-buffered"),
7400            "main.rs should document double-buffered prefill"
7401        );
7402    }
7403
7404    #[test]
7405    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7406        let shaders = generate_metal_shaders(&minimal_config());
7407
7408        assert!(
7409            shaders.contains("packed_short4"),
7410            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7411        );
7412        assert!(
7413            shaders.contains("d0[0]"),
7414            "matmul_vec_q8 should index the wide pointer for row 0"
7415        );
7416        assert!(
7417            shaders.contains("as_type<char2>"),
7418            "matmul_vec_q8 should bitcast short lanes to char2"
7419        );
7420        assert!(
7421            shaders.contains("dot("),
7422            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7423        );
7424    }
7425
7426    // ── Q4_0 tests ──────────────────────────────────────────────────────
7427
7428    fn minimal_q4_config() -> ModelConfig {
7429        ModelConfig {
7430            architecture: Architecture::Llama,
7431            hidden_size: 64,
7432            intermediate_size: 128,
7433            num_layers: 2,
7434            num_attention_heads: 4,
7435            num_kv_heads: 4,
7436            head_dim: 16,
7437            vocab_size: 256,
7438            max_seq_len: 512,
7439            rms_norm_eps: 1e-5,
7440            rope_theta: 10000.0,
7441            dtype: DType::Q4_0,
7442            sliding_window_size: None,
7443            qkv_bias: false,
7444        }
7445    }
7446
7447    #[test]
7448    fn generated_shaders_contain_q4_kernel() {
7449        let shaders = generate_metal_shaders(&minimal_config());
7450
7451        assert!(
7452            shaders.contains("kernel void matmul_vec_q4"),
7453            "shaders should contain matmul_vec_q4 kernel"
7454        );
7455        assert!(
7456            shaders.contains("Q4_ROWS_PER_TG"),
7457            "shaders should define Q4_ROWS_PER_TG constant"
7458        );
7459        assert!(
7460            shaders.contains("Q4_ROWS_PER_SG"),
7461            "shaders should define Q4_ROWS_PER_SG constant"
7462        );
7463    }
7464
7465    #[test]
7466    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7467        let shaders = generate_metal_shaders(&minimal_config());
7468
7469        // Q4_0 kernel should use uchar4 for packed byte loads
7470        assert!(
7471            shaders.contains("uchar4"),
7472            "matmul_vec_q4 should use uchar4 for packed byte loads"
7473        );
7474        // Should unpack nibbles with &0xF and >>4
7475        assert!(
7476            shaders.contains("&0xF"),
7477            "matmul_vec_q4 should extract low nibble with &0xF"
7478        );
7479        assert!(
7480            shaders.contains(">>4"),
7481            "matmul_vec_q4 should extract high nibble with >>4"
7482        );
7483        // Should subtract 8 to convert unsigned to signed
7484        assert!(
7485            shaders.contains("-8)"),
7486            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7487        );
7488        // Should use 18-byte block size
7489        assert!(
7490            shaders.contains("blk * 18"),
7491            "matmul_vec_q4 should use 18-byte block stride"
7492        );
7493    }
7494
7495    #[test]
7496    fn q4_model_has_matmul_q4_pipeline() {
7497        let config = minimal_q4_config();
7498        let model_rs = generate_model_rs(&config).unwrap();
7499
7500        assert!(
7501            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7502            "MetalModel should have matmul_q4_pipeline field"
7503        );
7504        assert!(
7505            model_rs.contains("matmul_q4_pipeline,"),
7506            "MetalModel Self should include matmul_q4_pipeline"
7507        );
7508    }
7509
7510    #[test]
7511    fn q4_model_uses_dispatch_matmul_q4() {
7512        let config = minimal_q4_config();
7513        let model_rs = generate_model_rs(&config).unwrap();
7514
7515        assert!(
7516            model_rs.contains("dispatch_matmul_q4"),
7517            "Q4_0 model should use dispatch_matmul_q4 for projections"
7518        );
7519        assert!(
7520            model_rs.contains("fn dispatch_matmul_q4"),
7521            "model.rs should define dispatch_matmul_q4 method"
7522        );
7523    }
7524
7525    #[test]
7526    fn q4_model_loads_raw_bytes_not_dequantized() {
7527        let config = minimal_q4_config();
7528        let model_rs = generate_model_rs(&config).unwrap();
7529
7530        // Should NOT contain dequantization code
7531        assert!(
7532            !model_rs.contains("f16_to_f32"),
7533            "Q4_0 model should not dequantize weights to f32"
7534        );
7535
7536        // Should load raw Q4_0 bytes directly
7537        assert!(
7538            model_rs.contains("total_raw as u64"),
7539            "Q4_0 model should load raw bytes into Metal buffer"
7540        );
7541    }
7542
7543    #[test]
7544    fn q4_model_norms_stay_f32() {
7545        let config = minimal_q4_config();
7546        let model_rs = generate_model_rs(&config).unwrap();
7547
7548        assert!(
7549            model_rs.contains("let attn_norm = next_f32_buffer"),
7550            "attn_norm should use f32 buffer even for Q4_0 models"
7551        );
7552        assert!(
7553            model_rs.contains("let ffn_norm = next_f32_buffer"),
7554            "ffn_norm should use f32 buffer even for Q4_0 models"
7555        );
7556        assert!(
7557            model_rs.contains("let norm_buf = next_f32_buffer"),
7558            "final norm should use f32 buffer even for Q4_0 models"
7559        );
7560    }
7561
7562    #[test]
7563    fn q4_model_uses_fused_weight_loading() {
7564        let config = minimal_q4_config();
7565        let model_rs = generate_model_rs(&config).unwrap();
7566
7567        assert!(
7568            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7569            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7570        );
7571        assert!(
7572            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7573            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7574        );
7575        assert!(
7576            model_rs.contains("let o_weight = next_q4_buffer"),
7577            "Q4_0 model should use next_q4_buffer for O weight"
7578        );
7579        assert!(
7580            model_rs.contains("let down_weight = next_q4_buffer"),
7581            "Q4_0 model should use next_q4_buffer for down weight"
7582        );
7583    }
7584
7585    #[test]
7586    fn attention_flash_batch_kernel_exists() {
7587        // The flash kernel is still wired into the library (pipeline, kernel
7588        // source).  Dispatch is currently routed to the legacy path pending a
7589        // fix for a numerical issue discovered after the prompt-chunking fix.
7590        let config = minimal_config();
7591        let model_rs = generate_model_rs(&config).unwrap();
7592        let shaders = generate_metal_shaders(&config);
7593
7594        assert!(
7595            shaders.contains("kernel void attention_flash_batch"),
7596            "shaders.metal must still contain the attention_flash_batch kernel"
7597        );
7598        assert!(
7599            shaders.contains("FLASH_K_TILE"),
7600            "flash kernel must tile K/V with a FLASH_K_TILE constant"
7601        );
7602        assert!(
7603            model_rs.contains("attention_flash_batch_pipeline"),
7604            "MetalModel must register the flash attention pipeline"
7605        );
7606    }
7607
7608    #[test]
7609    fn forward_prefill_batch_chunks_by_max_batch_size() {
7610        // Regression: prior to v0.6.4 forward_prefill_batch truncated prompts
7611        // longer than MAX_BATCH_SIZE (512) tokens via `.min(MAX_BATCH_SIZE)`,
7612        // silently dropping the middle of long prompts.  Must now loop over
7613        // MAX_BATCH_SIZE-sized chunks and carry KV-cache state across them.
7614        let config = minimal_config();
7615        let model_rs = generate_model_rs(&config).unwrap();
7616        assert!(
7617            model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
7618            "forward_prefill_batch must chunk long prompts"
7619        );
7620        assert!(
7621            !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
7622            "the old truncation path must be gone"
7623        );
7624    }
7625
7626    #[test]
7627    fn qwen2_qkv_bias_wired_through_metal_codegen() {
7628        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
7629        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
7630        // loader, pipeline, and dispatch call in the expected places.
7631        let config = ModelConfig {
7632            architecture: Architecture::Qwen2,
7633            qkv_bias: true,
7634            ..minimal_config()
7635        };
7636        let model_rs = generate_model_rs(&config).unwrap();
7637
7638        assert!(
7639            model_rs.contains("qkv_bias: Buffer"),
7640            "Qwen2 LayerBuffers must declare qkv_bias field"
7641        );
7642        assert!(
7643            model_rs.contains("let qkv_bias = next_f32_buffer"),
7644            "Qwen2 layer init must load the bias from the weight blob"
7645        );
7646        assert!(
7647            model_rs.contains("add_bias_batch_pipeline"),
7648            "Qwen2 model struct must include the add_bias_batch_pipeline"
7649        );
7650        assert!(
7651            model_rs.contains("fn dispatch_add_bias_batch"),
7652            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
7653        );
7654        assert!(
7655            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
7656            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
7657        );
7658        assert!(
7659            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
7660            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
7661        );
7662
7663        // The add_bias_batch MSL kernel must be in the shader source.
7664        let shaders = generate_metal_shaders(&config);
7665        assert!(
7666            shaders.contains("kernel void add_bias_batch"),
7667            "shaders.metal must contain the add_bias_batch kernel"
7668        );
7669    }
7670
7671    #[test]
7672    fn llama_does_not_emit_qkv_bias_machinery() {
7673        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
7674        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
7675        let config = minimal_config();
7676        assert!(!config.qkv_bias);
7677        let model_rs = generate_model_rs(&config).unwrap();
7678        assert!(
7679            !model_rs.contains("qkv_bias: Buffer"),
7680            "Llama must not have qkv_bias field"
7681        );
7682        assert!(
7683            !model_rs.contains("add_bias_batch_pipeline"),
7684            "Llama must not pull in add_bias_batch_pipeline"
7685        );
7686        assert!(
7687            !model_rs.contains("dispatch_add_bias_batch"),
7688            "Llama must not call dispatch_add_bias_batch"
7689        );
7690    }
7691
7692    #[test]
7693    fn q4_dispatch_helper_takes_compute_encoder_ref() {
7694        let config = minimal_q4_config();
7695        let model_rs = generate_model_rs(&config).unwrap();
7696
7697        assert!(
7698            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
7699            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
7700        );
7701    }
7702
7703    #[test]
7704    fn f32_model_does_not_use_q4_dispatch() {
7705        let config = minimal_config();
7706        let model_rs = generate_model_rs(&config).unwrap();
7707
7708        let forward_start = model_rs
7709            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7710            .unwrap();
7711        let forward_body = &model_rs[forward_start..];
7712        let forward_end = forward_body
7713            .find("\n    fn dispatch_")
7714            .unwrap_or(forward_body.len());
7715        let forward_code = &forward_body[..forward_end];
7716
7717        assert!(
7718            !forward_code.contains("dispatch_matmul_q4"),
7719            "f32 model forward should not use dispatch_matmul_q4"
7720        );
7721    }
7722
7723    #[test]
7724    fn q4_model_lm_head_uses_q4_buffer() {
7725        let config = minimal_q4_config();
7726        let model_rs = generate_model_rs(&config).unwrap();
7727
7728        assert!(
7729            model_rs.contains("let lm_head_buf = next_q4_buffer"),
7730            "Q4_0 model should use next_q4_buffer for lm_head"
7731        );
7732    }
7733
7734    #[test]
7735    fn vec_tile_size_matches_model_dimensions() {
7736        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
7737        let small = minimal_config();
7738        let shaders_small = generate_metal_shaders(&small);
7739        assert!(
7740            shaders_small.contains("vec_tile[128]"),
7741            "vec_tile should be sized to max(hidden, intermediate) = 128"
7742        );
7743
7744        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
7745        let mut large = minimal_config();
7746        large.hidden_size = 2048;
7747        large.intermediate_size = 8192;
7748        let shaders_large = generate_metal_shaders(&large);
7749        assert!(
7750            shaders_large.contains("vec_tile[8192]"),
7751            "vec_tile should be 8192 for models with intermediate=8192"
7752        );
7753        assert!(
7754            !shaders_large.contains("vec_tile[4096]"),
7755            "vec_tile should NOT be hardcoded to 4096"
7756        );
7757    }
7758
7759    #[test]
7760    fn generated_cargo_toml_has_server_deps() {
7761        let toml = generate_cargo_toml("my-model");
7762        assert!(
7763            toml.contains("tiny_http"),
7764            "Cargo.toml should depend on tiny_http"
7765        );
7766        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
7767        assert!(
7768            toml.contains("serde_json"),
7769            "Cargo.toml should depend on serde_json"
7770        );
7771    }
7772
7773    #[test]
7774    fn generated_main_rs_has_serve_mode() {
7775        let config = minimal_config();
7776        let main_rs = generate_main_rs("test-model", &config).unwrap();
7777
7778        assert!(
7779            main_rs.contains("--serve"),
7780            "main.rs should parse --serve flag"
7781        );
7782        assert!(
7783            main_rs.contains("--port"),
7784            "main.rs should parse --port flag"
7785        );
7786        assert!(
7787            main_rs.contains("fn serve("),
7788            "main.rs should define serve function"
7789        );
7790        assert!(
7791            main_rs.contains("tiny_http::Server::http"),
7792            "main.rs should create tiny_http server"
7793        );
7794    }
7795
7796    #[test]
7797    fn generated_main_rs_has_chat_completions_endpoint() {
7798        let config = minimal_config();
7799        let main_rs = generate_main_rs("test-model", &config).unwrap();
7800
7801        assert!(
7802            main_rs.contains("/v1/chat/completions"),
7803            "main.rs should handle /v1/chat/completions endpoint"
7804        );
7805        assert!(
7806            main_rs.contains("/v1/models"),
7807            "main.rs should handle /v1/models endpoint"
7808        );
7809        assert!(
7810            main_rs.contains("/health"),
7811            "main.rs should handle /health endpoint"
7812        );
7813    }
7814
7815    #[test]
7816    fn generated_main_rs_has_sse_streaming() {
7817        let config = minimal_config();
7818        let main_rs = generate_main_rs("test-model", &config).unwrap();
7819
7820        assert!(
7821            main_rs.contains("text/event-stream"),
7822            "main.rs should set SSE content type for streaming"
7823        );
7824        assert!(
7825            main_rs.contains("chat.completion.chunk"),
7826            "main.rs should emit SSE chunks"
7827        );
7828        assert!(
7829            main_rs.contains("[DONE]"),
7830            "main.rs should emit [DONE] sentinel"
7831        );
7832    }
7833
7834    #[test]
7835    fn generated_main_rs_has_chat_message_formatting() {
7836        let config = minimal_config();
7837        let main_rs = generate_main_rs("test-model", &config).unwrap();
7838
7839        assert!(
7840            main_rs.contains("fn format_chat_messages"),
7841            "main.rs should define format_chat_messages function"
7842        );
7843        assert!(
7844            main_rs.contains("<|im_start|>"),
7845            "main.rs should use ChatML format"
7846        );
7847        assert!(
7848            main_rs.contains("<|im_end|>"),
7849            "main.rs should use ChatML format"
7850        );
7851    }
7852
7853    #[test]
7854    fn generated_main_rs_has_request_types() {
7855        let config = minimal_config();
7856        let main_rs = generate_main_rs("test-model", &config).unwrap();
7857
7858        assert!(
7859            main_rs.contains("struct ChatRequest"),
7860            "main.rs should define ChatRequest struct"
7861        );
7862        assert!(
7863            main_rs.contains("struct ChatMessage"),
7864            "main.rs should define ChatMessage struct"
7865        );
7866        assert!(
7867            main_rs.contains("Deserialize"),
7868            "main.rs should derive Deserialize for request types"
7869        );
7870    }
7871
7872    #[test]
7873    fn generated_model_has_reset_method() {
7874        let config = minimal_config();
7875        let model_rs = generate_model_rs(&config).unwrap();
7876
7877        assert!(
7878            model_rs.contains("pub fn reset(&mut self)"),
7879            "model.rs should have a reset() method for multi-request serving"
7880        );
7881        assert!(
7882            model_rs.contains("self.pos = 0"),
7883            "reset() should reset position to 0"
7884        );
7885    }
7886
7887    #[test]
7888    fn generated_main_rs_cli_mode_still_works() {
7889        let config = minimal_config();
7890        let main_rs = generate_main_rs("test-model", &config).unwrap();
7891
7892        // CLI mode should still be functional
7893        assert!(
7894            main_rs.contains("fn cli_mode("),
7895            "main.rs should define cli_mode function"
7896        );
7897        assert!(
7898            main_rs.contains("model.forward"),
7899            "main.rs should call model.forward"
7900        );
7901        assert!(
7902            main_rs.contains("model.forward_prefill"),
7903            "main.rs should call model.forward_prefill"
7904        );
7905    }
7906
7907    // ── Batched prefill tests ──────────────────────────────────────────
7908
7909    #[test]
7910    fn generated_shaders_contain_batch_kernels() {
7911        let shaders = generate_metal_shaders(&minimal_config());
7912
7913        assert!(
7914            shaders.contains("kernel void matmul_vec_batch"),
7915            "shaders should contain matmul_vec_batch kernel"
7916        );
7917        assert!(
7918            shaders.contains("kernel void matmul_vec_q8_batch"),
7919            "shaders should contain matmul_vec_q8_batch kernel"
7920        );
7921        assert!(
7922            shaders.contains("kernel void matmul_q8_gemm_batch"),
7923            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
7924        );
7925        assert!(
7926            shaders.contains("kernel void matmul_vec_q4_batch"),
7927            "shaders should contain matmul_vec_q4_batch kernel"
7928        );
7929        assert!(
7930            shaders.contains("kernel void rms_norm_batch"),
7931            "shaders should contain rms_norm_batch kernel"
7932        );
7933        assert!(
7934            shaders.contains("kernel void silu_mul_fused_batch"),
7935            "shaders should contain silu_mul_fused_batch kernel"
7936        );
7937        assert!(
7938            shaders.contains("kernel void add_inplace_batch"),
7939            "shaders should contain add_inplace_batch kernel"
7940        );
7941        assert!(
7942            shaders.contains("kernel void copy_embedding_batch"),
7943            "shaders should contain copy_embedding_batch kernel"
7944        );
7945    }
7946
7947    #[test]
7948    fn generated_model_has_batch_pipelines() {
7949        let config = minimal_config();
7950        let model_rs = generate_model_rs(&config).unwrap();
7951
7952        for pipeline in &[
7953            "matmul_batch_pipeline",
7954            "matmul_q8_batch_pipeline",
7955            "matmul_q8_gemm_batch_pipeline",
7956            "matmul_q4_batch_pipeline",
7957            "rms_norm_batch_pipeline",
7958            "rope_batch_pipeline",
7959            "silu_mul_fused_batch_pipeline",
7960            "add_inplace_batch_pipeline",
7961            "copy_embedding_batch_pipeline",
7962            "attention_batch_pipeline",
7963            "copy_kv_batch_pipeline",
7964        ] {
7965            assert!(
7966                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
7967                "MetalModel should have {pipeline} field"
7968            );
7969        }
7970    }
7971
7972    #[test]
7973    fn generated_model_has_batch_buffers() {
7974        let config = minimal_config();
7975        let model_rs = generate_model_rs(&config).unwrap();
7976
7977        for buf in &[
7978            "batch_hidden_buf",
7979            "batch_residual_buf",
7980            "batch_qkv_buf",
7981            "batch_attn_out_buf",
7982            "batch_attn_proj_buf",
7983            "batch_gate_up_buf",
7984            "batch_ffn_hidden_buf",
7985            "batch_ffn_out_buf",
7986            "batch_tokens_buf",
7987            "batch_positions_buf",
7988        ] {
7989            assert!(
7990                model_rs.contains(&format!("{buf}: Buffer")),
7991                "MetalModel should have {buf} field"
7992            );
7993        }
7994    }
7995
7996    #[test]
7997    fn generated_model_has_forward_prefill_batch() {
7998        let config = minimal_config();
7999        let model_rs = generate_model_rs(&config).unwrap();
8000
8001        assert!(
8002            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8003            "MetalModel should have forward_prefill_batch method"
8004        );
8005
8006        // forward_prefill should delegate to forward_prefill_batch
8007        assert!(
8008            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8009            "forward_prefill should delegate to forward_prefill_batch"
8010        );
8011    }
8012
8013    #[test]
8014    fn generated_model_has_max_batch_size_constant() {
8015        let config = minimal_config();
8016        let model_rs = generate_model_rs(&config).unwrap();
8017
8018        assert!(
8019            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8020            "model.rs should define MAX_BATCH_SIZE constant"
8021        );
8022    }
8023
8024    #[test]
8025    fn forward_prefill_batch_uses_batch_dispatch() {
8026        let config = minimal_config();
8027        let model_rs = generate_model_rs(&config).unwrap();
8028
8029        let batch_start = model_rs
8030            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8031            .unwrap();
8032        let batch_body = &model_rs[batch_start..];
8033        let batch_end = batch_body
8034            .find("\n    pub fn reset")
8035            .unwrap_or(batch_body.len());
8036        let batch_code = &batch_body[..batch_end];
8037
8038        // Should use batched dispatch methods
8039        assert!(
8040            batch_code.contains("dispatch_rms_norm_batch"),
8041            "forward_prefill_batch should use dispatch_rms_norm_batch"
8042        );
8043        assert!(
8044            batch_code.contains("dispatch_copy_embedding_batch"),
8045            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8046        );
8047        assert!(
8048            batch_code.contains("dispatch_silu_mul_fused_batch"),
8049            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8050        );
8051        // Should use batched causal attention dispatch
8052        assert!(
8053            batch_code.contains("dispatch_attention_batch"),
8054            "forward_prefill_batch should use dispatch_attention_batch"
8055        );
8056        // Should use fused KV cache copy (both K and V in one dispatch)
8057        assert!(
8058            batch_code.contains("dispatch_copy_kv_both_batch"),
8059            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8060        );
8061        // Should use fused RoPE Q+K dispatch
8062        assert!(
8063            batch_code.contains("dispatch_rope_qk_batch"),
8064            "forward_prefill_batch should use dispatch_rope_qk_batch"
8065        );
8066    }
8067
8068    #[test]
8069    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8070        let config = minimal_q8_config();
8071        let model_rs = generate_model_rs(&config).unwrap();
8072
8073        let batch_start = model_rs
8074            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8075            .unwrap();
8076        let batch_body = &model_rs[batch_start..];
8077        let batch_end = batch_body
8078            .find("\n    pub fn reset")
8079            .unwrap_or(batch_body.len());
8080        let batch_code = &batch_body[..batch_end];
8081
8082        assert!(
8083            batch_code.contains("dispatch_matmul_q8_batch"),
8084            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8085        );
8086    }
8087
8088    #[test]
8089    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8090        let config = minimal_q4_config();
8091        let model_rs = generate_model_rs(&config).unwrap();
8092
8093        let batch_start = model_rs
8094            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8095            .unwrap();
8096        let batch_body = &model_rs[batch_start..];
8097        let batch_end = batch_body
8098            .find("\n    pub fn reset")
8099            .unwrap_or(batch_body.len());
8100        let batch_code = &batch_body[..batch_end];
8101
8102        assert!(
8103            batch_code.contains("dispatch_matmul_q4_batch"),
8104            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8105        );
8106    }
8107
8108    #[test]
8109    fn generated_main_rs_uses_batched_prefill() {
8110        let config = minimal_config();
8111        let main_rs = generate_main_rs("test-model", &config).unwrap();
8112
8113        assert!(
8114            main_rs.contains("forward_prefill_batch"),
8115            "main.rs should use forward_prefill_batch for prompt tokens"
8116        );
8117    }
8118
8119    #[test]
8120    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8121        let config = minimal_config();
8122        let model_rs = generate_model_rs(&config).unwrap();
8123
8124        let batch_start = model_rs
8125            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8126            .unwrap();
8127        let batch_body = &model_rs[batch_start..];
8128        let batch_end = batch_body
8129            .find("\n    pub fn reset")
8130            .unwrap_or(batch_body.len());
8131        let batch_code = &batch_body[..batch_end];
8132
8133        assert!(
8134            batch_code.contains("dispatch_matmul_batch"),
8135            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8136        );
8137        // Should NOT use Q8 or Q4 batch dispatch
8138        assert!(
8139            !batch_code.contains("dispatch_matmul_q8_batch"),
8140            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8141        );
8142        assert!(
8143            !batch_code.contains("dispatch_matmul_q4_batch"),
8144            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8145        );
8146    }
8147
8148    #[test]
8149    fn forward_uses_cpu_embedding_lookup() {
8150        let config = minimal_config();
8151        let model_rs = generate_model_rs(&config).unwrap();
8152
8153        // Find just the forward() body (not forward_profile)
8154        let forward_start = model_rs
8155            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8156            .unwrap();
8157        let forward_body = &model_rs[forward_start..];
8158        let forward_end = forward_body
8159            .find("\n    pub fn forward_profile")
8160            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8161            .unwrap_or(forward_body.len());
8162        let forward_code = &forward_body[..forward_end];
8163
8164        // forward() should use CPU memcpy for embedding lookup (unified memory)
8165        assert!(
8166            forward_code.contains("embed_buf.contents()"),
8167            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8168        );
8169        assert!(
8170            forward_code.contains("copy_nonoverlapping"),
8171            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8172        );
8173        // forward() should NOT use GPU dispatch for embedding
8174        assert!(
8175            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8176            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8177        );
8178    }
8179
8180    #[test]
8181    fn forward_profile_method_exists() {
8182        let config = minimal_config();
8183        let model_rs = generate_model_rs(&config).unwrap();
8184
8185        assert!(
8186            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8187            "MetalModel should have forward_profile() method"
8188        );
8189        // Profile method should print timing information
8190        assert!(
8191            model_rs.contains("[profile]"),
8192            "forward_profile() should print timing with [profile] prefix"
8193        );
8194        assert!(
8195            model_rs.contains("d_embed"),
8196            "forward_profile() should measure embedding time"
8197        );
8198        assert!(
8199            model_rs.contains("d_layers"),
8200            "forward_profile() should measure layer time"
8201        );
8202        assert!(
8203            model_rs.contains("d_logits"),
8204            "forward_profile() should measure logits time"
8205        );
8206    }
8207
8208    #[test]
8209    fn generated_cli_has_profile_flag() {
8210        let config = minimal_config();
8211        let main_rs = generate_main_rs("test-model", &config).unwrap();
8212
8213        assert!(
8214            main_rs.contains("--profile"),
8215            "CLI should support --profile flag"
8216        );
8217        assert!(
8218            main_rs.contains("forward_profile"),
8219            "CLI should call forward_profile when --profile is set"
8220        );
8221    }
8222
8223    #[test]
8224    fn generated_cli_has_thermal_yield() {
8225        let config = minimal_config();
8226        let main_rs = generate_main_rs("test-model", &config).unwrap();
8227
8228        assert!(
8229            main_rs.contains("yield_now()"),
8230            "CLI generation loop should include thread::yield_now() for thermal management"
8231        );
8232    }
8233
8234    // ── Real-world validation tests ──────────────────────────────────────
8235
8236    #[test]
8237    fn generated_forward_handles_single_token_prompt() {
8238        // With a single token (the first prompt token), forward() should work
8239        // at pos=0 where seq_len=1. The attention kernel must handle the case
8240        // where there is only one KV entry (no prefill context).
8241        let config = minimal_config();
8242        let model_rs = generate_model_rs(&config).unwrap();
8243
8244        // The forward function should accept any u32 token_id (no minimum pos guard)
8245        let forward_start = model_rs
8246            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8247            .expect("forward() must exist");
8248        let forward_body = &model_rs[forward_start..forward_start + 400];
8249
8250        // Should NOT require pos > 0 or seq_len > 1
8251        assert!(
8252            !forward_body.contains("assert!(self.pos > 0"),
8253            "forward() must accept pos=0 (first token with no prefill)"
8254        );
8255
8256        // The attention kernel should handle seq_len=1 via the pos field
8257        assert!(
8258            model_rs.contains("self.pos"),
8259            "forward() should use self.pos to track sequence position"
8260        );
8261    }
8262
8263    #[test]
8264    fn generated_reset_clears_kv_cache_position() {
8265        // After reset(), the model should be in a clean state. The pos field
8266        // must be 0 so new generation starts from scratch.
8267        let config = minimal_config();
8268        let model_rs = generate_model_rs(&config).unwrap();
8269
8270        let reset_start = model_rs
8271            .find("pub fn reset(&mut self)")
8272            .expect("reset() must exist");
8273        let reset_body = &model_rs[reset_start..reset_start + 200];
8274
8275        // Reset must zero the position counter
8276        assert!(
8277            reset_body.contains("self.pos = 0"),
8278            "reset() must set self.pos = 0"
8279        );
8280
8281        // Verify reset clears prev_cmd (double-buffering state)
8282        assert!(
8283            reset_body.contains("self.prev_cmd = None"),
8284            "reset() should clear prev_cmd for clean command buffer state"
8285        );
8286    }
8287
8288    #[test]
8289    fn generated_serve_handles_empty_messages_gracefully() {
8290        // The serve endpoint should not crash when receiving an empty messages array.
8291        // The format_chat_messages function should handle this gracefully.
8292        let config = minimal_config();
8293        let main_rs = generate_main_rs("test-model", &config).unwrap();
8294
8295        // The format_chat_messages function should exist and handle empty input
8296        let format_fn_start = main_rs
8297            .find("fn format_chat_messages")
8298            .expect("format_chat_messages must exist");
8299        let format_fn_body =
8300            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8301
8302        // It should iterate over messages (an empty slice produces an empty loop)
8303        assert!(
8304            format_fn_body.contains("for msg in messages"),
8305            "format_chat_messages should iterate over the messages slice"
8306        );
8307        // It should always append the assistant prompt suffix
8308        assert!(
8309            format_fn_body.contains("<|im_start|>assistant"),
8310            "format_chat_messages should always append assistant prompt header"
8311        );
8312
8313        // The serve function should call model.reset() before each request
8314        let serve_fn_start = main_rs
8315            .find("fn serve(")
8316            .expect("serve function must exist");
8317        let serve_fn_body = &main_rs[serve_fn_start..];
8318        assert!(
8319            serve_fn_body.contains("model.reset()"),
8320            "serve function should reset model between requests"
8321        );
8322    }
8323
8324    #[test]
8325    fn generated_model_forward_increments_pos() {
8326        // Each forward() call must increment self.pos so the next token
8327        // uses the correct RoPE position and KV cache offset.
8328        let config = minimal_config();
8329        let model_rs = generate_model_rs(&config).unwrap();
8330
8331        let forward_start = model_rs
8332            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8333            .unwrap();
8334        let forward_body = &model_rs[forward_start..];
8335        let forward_end = forward_body
8336            .find("\n    pub fn forward_profile")
8337            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8338            .or_else(|| forward_body.find("\n    fn dispatch_"))
8339            .unwrap_or(forward_body.len());
8340        let forward_code = &forward_body[..forward_end];
8341
8342        assert!(
8343            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8344            "forward() must increment self.pos after processing a token"
8345        );
8346    }
8347}