Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a string usable as a Cargo package name.
///
/// The name is lowercased, every character that is neither alphanumeric nor
/// `-` is replaced by `-`, and leading/trailing dashes are trimmed. Two
/// results that Cargo rejects as package names are then repaired:
///
/// - an empty result (empty input, or input made only of separators) falls
///   back to `"model"`;
/// - a result starting with an ASCII digit gets a `"model-"` prefix.
///
/// Names that were already valid come out unchanged.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    let dashed: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = dashed.trim_matches('-');

    // After trimming, the first character (if any) is alphanumeric, so only
    // the "empty" and "leading digit" cases need fixing.
    match trimmed.chars().next() {
        None => FALLBACK.to_string(),
        Some(first) if first.is_ascii_digit() => format!("{}-{}", FALLBACK, trimmed),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    // The attention-scores shared array must be at least effective_seq_len
122    // elements; anything smaller silently overflows for long prompts.  For
123    // small-context models (135M at 2K), a smaller array saves threadgroup
124    // memory and improves occupancy — so we size it precisely.
125    let attn_scores_size = config.max_seq_len.min(4096);
126    r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154    device const float* matrix [[buffer(0)]],
155    device const float* vector [[buffer(1)]],
156    device float* output       [[buffer(2)]],
157    constant uint& rows        [[buffer(3)]],
158    constant uint& cols        [[buffer(4)]],
159    uint tgid [[threadgroup_position_in_grid]],
160    uint tid [[thread_index_in_threadgroup]],
161    uint simd_lane [[thread_index_in_simdgroup]],
162    uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164    // Cooperatively load vector into threadgroup shared memory
165    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166    for (uint i = tid; i < cols; i += 256) {
167        vec_tile[i] = vector[i];
168    }
169    threadgroup_barrier(mem_flags::mem_threadgroup);
170
171    // Each simdgroup handles 8 consecutive rows
172    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173    if (row_base >= rows) return;
174
175    uint base0 = row_base * cols;
176    uint base1 = (row_base + 1) * cols;
177    uint base2 = (row_base + 2) * cols;
178    uint base3 = (row_base + 3) * cols;
179    uint base4 = (row_base + 4) * cols;
180    uint base5 = (row_base + 5) * cols;
181    uint base6 = (row_base + 6) * cols;
182    uint base7 = (row_base + 7) * cols;
183
184    // float4 vectorized accumulation across 8 rows
185    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
186    float4 sum4_0 = float4(0.0f);
187    float4 sum4_1 = float4(0.0f);
188    float4 sum4_2 = float4(0.0f);
189    float4 sum4_3 = float4(0.0f);
190    float4 sum4_4 = float4(0.0f);
191    float4 sum4_5 = float4(0.0f);
192    float4 sum4_6 = float4(0.0f);
193    float4 sum4_7 = float4(0.0f);
194
195    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205    }
206
207    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216    // Handle remaining elements (cols not divisible by 128)
217    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218        float vv = vec_tile[j];
219        sum0 += matrix[base0 + j] * vv;
220        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227    }
228
229    // Simdgroup hardware warp-level reduction
230    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235    // Only first lane writes the results
236    if (simd_lane == 0) {
237        if (row_base     < rows) output[row_base]     = sum0;
238        if (row_base + 1 < rows) output[row_base + 1] = sum1;
239        if (row_base + 2 < rows) output[row_base + 2] = sum2;
240        if (row_base + 3 < rows) output[row_base + 3] = sum3;
241        if (row_base + 4 < rows) output[row_base + 4] = sum4;
242        if (row_base + 5 < rows) output[row_base + 5] = sum5;
243        if (row_base + 6 < rows) output[row_base + 6] = sum6;
244        if (row_base + 7 < rows) output[row_base + 7] = sum7;
245    }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253    device const float* input   [[buffer(0)]],
254    device const float* weight  [[buffer(1)]],
255    device float* output        [[buffer(2)]],
256    constant uint& n            [[buffer(3)]],
257    constant float& eps         [[buffer(4)]],
258    uint tid [[thread_index_in_threadgroup]])
259{
260    // Each thread accumulates partial sum-of-squares
261    float sum_sq = 0.0f;
262    for (uint i = tid; i < n; i += 256) {
263        float v = input[i];
264        sum_sq += v * v;
265    }
266
267    // Simdgroup-level reduction (hardware warp sum)
268    sum_sq = simd_sum(sum_sq);
269
270    // Cross-simdgroup reduction via shared memory
271    threadgroup float shared[8];
272    uint simd_id = tid / 32;
273    uint simd_lane = tid % 32;
274    if (simd_lane == 0) {
275        shared[simd_id] = sum_sq;
276    }
277    threadgroup_barrier(mem_flags::mem_threadgroup);
278
279    // First thread computes final inverse RMS
280    if (tid == 0) {
281        float total = 0.0f;
282        for (uint i = 0; i < 8; i++) {
283            total += shared[i];
284        }
285        shared[0] = fast::rsqrt(total / float(n) + eps);
286    }
287    threadgroup_barrier(mem_flags::mem_threadgroup);
288
289    float inv_rms = shared[0];
290
291    // Normalize
292    for (uint i = tid; i < n; i += 256) {
293        output[i] = input[i] * inv_rms * weight[i];
294    }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301    device float* data        [[buffer(0)]],
302    constant uint& num_heads  [[buffer(1)]],
303    constant uint& head_dim   [[buffer(2)]],
304    constant uint& pos        [[buffer(3)]],
305    constant float& theta     [[buffer(4)]],
306    uint id [[thread_position_in_grid]])
307{
308    uint half_dim = head_dim / 2;
309    uint total_pairs = num_heads * half_dim;
310    if (id >= total_pairs) return;
311
312    uint h = id / half_dim;
313    uint i = id % half_dim;
314    uint off = h * head_dim;
315
316    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317    float angle = float(pos) * freq;
318    float c = cos(angle);
319    float s = sin(angle);
320
321    float x0 = data[off + 2 * i];
322    float x1 = data[off + 2 * i + 1];
323    data[off + 2 * i]     = x0 * c - x1 * s;
324    data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331    device float* data       [[buffer(0)]],
332    constant uint& n         [[buffer(1)]],
333    uint tid [[thread_index_in_threadgroup]],
334    uint tg_size [[threads_per_threadgroup]])
335{
336    threadgroup float shared_val[256];
337
338    // Pass 1: find max
339    float local_max = -INFINITY;
340    for (uint i = tid; i < n; i += tg_size) {
341        local_max = max(local_max, data[i]);
342    }
343    shared_val[tid] = local_max;
344    threadgroup_barrier(mem_flags::mem_threadgroup);
345
346    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347        if (tid < stride) {
348            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349        }
350        threadgroup_barrier(mem_flags::mem_threadgroup);
351    }
352    float max_val = shared_val[0];
353    threadgroup_barrier(mem_flags::mem_threadgroup);
354
355    // Pass 2: exp and sum
356    float local_sum = 0.0f;
357    for (uint i = tid; i < n; i += tg_size) {
358        float e = fast::exp(data[i] - max_val);
359        data[i] = e;
360        local_sum += e;
361    }
362    shared_val[tid] = local_sum;
363    threadgroup_barrier(mem_flags::mem_threadgroup);
364
365    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366        if (tid < stride) {
367            shared_val[tid] += shared_val[tid + stride];
368        }
369        threadgroup_barrier(mem_flags::mem_threadgroup);
370    }
371    float inv_sum = 1.0f / shared_val[0];
372    threadgroup_barrier(mem_flags::mem_threadgroup);
373
374    // Pass 3: normalize
375    for (uint i = tid; i < n; i += tg_size) {
376        data[i] *= inv_sum;
377    }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384    device const float* gate [[buffer(0)]],
385    device const float* up   [[buffer(1)]],
386    device float* output     [[buffer(2)]],
387    constant uint& n         [[buffer(3)]],
388    uint id [[thread_position_in_grid]])
389{
390    if (id >= n) return;
391    float g = gate[id];
392    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397//   gate = gate_up[0..n], up = gate_up[n..2*n]
398//   output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400    device const float* gate_up [[buffer(0)]],
401    device float* output        [[buffer(1)]],
402    constant uint& n            [[buffer(2)]],
403    uint id [[thread_position_in_grid]])
404{
405    if (id >= n) return;
406    float g = gate_up[id];
407    float u = gate_up[n + id];
408    output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414    device const float* a  [[buffer(0)]],
415    device const float* b  [[buffer(1)]],
416    device float* output   [[buffer(2)]],
417    constant uint& n       [[buffer(3)]],
418    uint id [[thread_position_in_grid]])
419{
420    if (id >= n) return;
421    output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428    device const float* src [[buffer(0)]],
429    device float* dst       [[buffer(1)]],
430    constant uint& count    [[buffer(2)]],
431    uint id [[thread_position_in_grid]])
432{
433    if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440    device const float* src     [[buffer(0)]],
441    device float* dst           [[buffer(1)]],
442    constant uint& src_offset   [[buffer(2)]],  // in floats
443    constant uint& count        [[buffer(3)]],
444    uint id [[thread_position_in_grid]])
445{
446    if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
450// Copy f32 elements from src into a half-typed dst, converting to half on
451// write.  Used by the single-token decode path to append a new K/V vector
452// to the f16 KV cache.  Byte offsets into src/dst are supplied via the
453// Metal buffer binding offsets — no in-kernel offsets needed.
454kernel void copy_f32_to_f16_offset(
455    device const float* src     [[buffer(0)]],
456    device half* dst            [[buffer(1)]],
457    constant uint& count        [[buffer(2)]],
458    uint id [[thread_position_in_grid]])
459{
460    if (id < count) dst[id] = half(src[id]);
461}
462
463// ── add_inplace ─────────────────────────────────────────────────────────
464// In-place residual connection: a[i] += b[i]
465// Avoids a separate blit copy for residual add, reducing encoder overhead.
466kernel void add_inplace(
467    device float* a        [[buffer(0)]],
468    device const float* b  [[buffer(1)]],
469    constant uint& n       [[buffer(2)]],
470    uint id [[thread_position_in_grid]])
471{
472    if (id >= n) return;
473    a[id] += b[id];
474}
475
476// ── matmul_vec_q8 ─────────────────────────────────────────────────────
477// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
478// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
479// Operates directly on quantized weights to halve memory bandwidth vs f32,
480// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
481//
482// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
483// because int8->float conversion doubles register demand.  Fully unrolled
484// inner loop with float4 vector loads from shared memory eliminates loop
485// overhead and enables better instruction scheduling.
486// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
487constant constexpr uint Q8_ROWS_PER_SG = 4;
488constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
489
490// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
491// doubles ALU work, so the same register budget applies).
492constant constexpr uint Q4_ROWS_PER_SG = 4;
493constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
494
495kernel void matmul_vec_q8(
496    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
497    device const float* vector   [[buffer(1)]],  // f32 input
498    device float* output         [[buffer(2)]],
499    constant uint& rows          [[buffer(3)]],
500    constant uint& cols          [[buffer(4)]],  // number of elements per row
501    uint tgid [[threadgroup_position_in_grid]],
502    uint tid [[thread_index_in_threadgroup]],
503    uint simd_lane [[thread_index_in_simdgroup]],
504    uint simd_id [[simdgroup_index_in_threadgroup]])
505{
506    // Load vector into shared memory
507    threadgroup float vec_tile[VEC_TILE_SIZE];
508    for (uint i = tid; i < cols; i += 256) {
509        vec_tile[i] = vector[i];
510    }
511    threadgroup_barrier(mem_flags::mem_threadgroup);
512
513    // Each simdgroup handles 4 consecutive rows (lower register pressure)
514    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
515    if (row_base >= rows) return;
516
517    // Q8_0: each block is 34 bytes for 32 elements
518    uint blocks_per_row = cols / 32;
519    uint row_bytes = blocks_per_row * 34;
520
521    // Pointers to each row's Q8_0 data
522    device const uchar* r0 = matrix + row_base * row_bytes;
523    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
524    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
525    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
526
527    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
528
529    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
530        uint bb = blk * 34;
531        uint vb = blk * 32;
532
533        // Prefetch all 4 scales
534        float sc0 = float(*(device const half*)(r0 + bb));
535        float sc1 = float(*(device const half*)(r1 + bb));
536        float sc2 = float(*(device const half*)(r2 + bb));
537        float sc3 = float(*(device const half*)(r3 + bb));
538
539        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
540        // Q8_0 block layout where the int8 data starts at offset +2 from a
541        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
542        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
543        // reduction in memory transactions. Metal's char16/packed_char16 are
544        // reserved types and packed_*int4 require >=4-byte alignment which
545        // this layout does not provide, so packed_short4 is the widest valid
546        // vectorized load.
547        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
548        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
549        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
550        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
551
552        // Load all 8 float4 vector values for this 32-element block from shared memory
553        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
554        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
555        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
556        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
557        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
558        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
559        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
560        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
561
562        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
563        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
564        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
565            short4 _s = short4(SHORT4); \
566            char2 _a = as_type<char2>(_s.x); \
567            char2 _b = as_type<char2>(_s.y); \
568            char2 _c = as_type<char2>(_s.z); \
569            char2 _d = as_type<char2>(_s.w); \
570            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
571            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
572        }
573
574        float4 f0, f1;
575        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
576
577        // Row 0: 4 short4 loads cover 32 int8 weights
578        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
579        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
580        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
581        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
582
583        // Row 1
584        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
585        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
586        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
587        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
588
589        // Row 2
590        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
591        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
592        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
593        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
594
595        // Row 3
596        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
597        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
598        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
599        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
600
601        #undef Q8_UNPACK8
602
603        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
604    }
605
606    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
607    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
608
609    if (simd_lane == 0) {
610        if (row_base     < rows) output[row_base]     = sum0;
611        if (row_base + 1 < rows) output[row_base + 1] = sum1;
612        if (row_base + 2 < rows) output[row_base + 2] = sum2;
613        if (row_base + 3 < rows) output[row_base + 3] = sum3;
614    }
615}
616
617// ── matmul_vec_q4 ─────────────────────────────────────────────────────
618// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
619// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
620// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
621// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
622//
623// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
624// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
625kernel void matmul_vec_q4(
626    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
627    device const float* vector   [[buffer(1)]],  // f32 input
628    device float* output         [[buffer(2)]],
629    constant uint& rows          [[buffer(3)]],
630    constant uint& cols          [[buffer(4)]],  // number of elements per row
631    uint tgid [[threadgroup_position_in_grid]],
632    uint tid [[thread_index_in_threadgroup]],
633    uint simd_lane [[thread_index_in_simdgroup]],
634    uint simd_id [[simdgroup_index_in_threadgroup]])
635{
636    // Load vector into shared memory
637    threadgroup float vec_tile[VEC_TILE_SIZE];
638    for (uint i = tid; i < cols; i += 256) {
639        vec_tile[i] = vector[i];
640    }
641    threadgroup_barrier(mem_flags::mem_threadgroup);
642
643    // Each simdgroup handles 4 consecutive rows
644    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
645    if (row_base >= rows) return;
646
647    // Q4_0: each block is 18 bytes for 32 elements
648    uint blocks_per_row = cols / 32;
649    uint row_bytes = blocks_per_row * 18;
650
651    // Pointers to each row's Q4_0 data
652    device const uchar* r0 = matrix + row_base * row_bytes;
653    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
654    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
655    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
656
657    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
658
659    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
660        uint bb = blk * 18;
661        uint vb = blk * 32;
662
663        // Prefetch all 4 scales
664        float sc0 = float(*(device const half*)(r0 + bb));
665        float sc1 = float(*(device const half*)(r1 + bb));
666        float sc2 = float(*(device const half*)(r2 + bb));
667        float sc3 = float(*(device const half*)(r3 + bb));
668
669        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
670        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
671        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
672        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
673        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
674
675        // Load 8 float4 vector values for 32 elements from shared memory
676        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
677        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
678        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
679        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
680        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
681        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
682        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
683        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
684        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
685
686        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
687        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
688        float bd0=0, bd1=0, bd2=0, bd3=0;
689        uchar4 b;
690
691        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
692        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
693                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
694                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
695                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
696        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
697                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
698                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
699                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
700        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
701                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
702                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
703                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
704        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
705                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
706                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
707                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
708
709        // Row 1
710        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
711                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
712                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
713                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
714        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
715                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
716                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
717                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
718        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
719                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
720                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
721                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
722        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
723                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
724                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
725                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
726
727        // Row 2
728        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
729                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
730                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
731                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
732        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
733                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
734                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
735                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
736        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
737                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
738                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
739                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
740        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
741                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
742                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
743                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
744
745        // Row 3
746        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
747                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
748                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
749                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
750        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
751                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
752                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
753                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
754        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
755                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
756                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
757                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
758        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
759                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
760                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
761                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
762
763        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
764    }
765
766    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
767    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
768
769    if (simd_lane == 0) {
770        if (row_base     < rows) output[row_base]     = sum0;
771        if (row_base + 1 < rows) output[row_base + 1] = sum1;
772        if (row_base + 2 < rows) output[row_base + 2] = sum2;
773        if (row_base + 3 < rows) output[row_base + 3] = sum3;
774    }
775}
776
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention for one head per threadgroup, using simdgroup
// cooperative reductions.
//
// Pipeline:
//   1. scores[s] = (Q · K[s]) * rsqrt(head_dim)   one simdgroup per position s
//   2. softmax over scores[0..seq_len)             simd_max / simd_sum, then an
//                                                  8-slot threadgroup reduce
//   3. output    = Σ_s scores[s] * V[s]            one thread per 4 output dims
//
// Dispatch assumptions baked into the strides below: 256 threads per
// threadgroup (8 simdgroups of 32 lanes), one threadgroup per head.
// The scores tile is not bounds-checked: the caller must keep
// seq_len <= ATTN_SCORES_SIZE.
//
// Buffers:
//   q:       [num_heads * head_dim]                   current query (float)
//   k_cache: [max_seq_len * num_kv_heads * head_dim]  keys (half)
//   v_cache: [max_seq_len * num_kv_heads * head_dim]  values (half)
//   output:  [num_heads * head_dim]                   attention result (float)
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const half*  k_cache  [[buffer(1)]],
    device const half*  v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;
    // Grouped-query attention: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Step 1: Compute attention scores Q·K^T with simdgroup reduction.
    // Scores live in threadgroup memory; the tile holds ATTN_SCORES_SIZE
    // entries and seq_len is assumed to be capped to that by the caller.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Number of 4-wide groups handled by the vectorized (float4/half4) path.
    //
    // The vector loads/stores below require every head/row offset to be a
    // multiple of 4 elements (float4 is 16-byte aligned, half4 is 8-byte
    // aligned). Offsets are multiples of head_dim, so that only holds when
    // head_dim itself is a multiple of 4. For any other head_dim the vector
    // path is disabled entirely and the scalar loops cover all dimensions.
    uint head_dim4 = ((head_dim & 3u) == 0u) ? (head_dim / 4) : 0u;

    // Each simdgroup handles one position s; its 32 lanes cover head_dim in
    // chunks of 4. For head_dim=128 every lane does exactly one half4 load;
    // for head_dim=64, 16 lanes are active; for head_dim=96, 24 lanes.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d4 = simd_lane; d4 < head_dim4; d4 += 32) {
            uint d = d4 * 4;
            float4 q4 = *(device const float4*)(q + q_off + d);
            float4 k4 = float4(*(device const half4*)(k_cache + k_off + d));
            dot += q4.x * k4.x + q4.y * k4.y + q4.z * k4.z + q4.w * k4.w;
        }
        // Scalar path: covers everything when head_dim is not a multiple of
        // 4, and is a no-op otherwise (head_dim4 * 4 == head_dim).
        for (uint d = head_dim4 * 4 + simd_lane; d < head_dim; d += 32) {
            dot += q[q_off + d] * float(k_cache[k_off + d]);
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax over scores (cooperative).
    // 2a. Max: per-thread strided scan, simdgroup reduce, then thread 0
    //     folds the 8 per-simdgroup partials.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b. Exponentiate (shifted by the max for numerical stability) and sum.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Store the reciprocal so normalization is a multiply.
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    // 2c. Normalize in place.
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V: output = scores · V.
    // Vector path: each thread owns 4 consecutive output dims and walks the
    // sequence 4 positions at a time, with a remainder loop for the tail.
    uint seq_len4 = seq_len & ~3u;  // largest multiple of 4 <= seq_len
    uint v_stride = num_kv_heads * head_dim;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
                 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
                 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
                 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
        }
        *(device float4*)(output + q_off + d) = acc;
    }
    // Scalar path: one output dim per thread. Handles all dims when the
    // vector path is disabled; no-op when head_dim is a multiple of 4.
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * float(v_cache[s * v_stride + v_base]);
        }
        output[q_off + d] = acc;
    }
}
915
916// ── Batched prefill kernels ────────────────────────────────────────────
917// These kernels process M input vectors against the same weight matrix
918// in a single dispatch, converting mat-vec into mat-mat for better GPU
919// utilization during prompt prefill.
920
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization: output[t, i] = input[t, i] * inv_rms(t) * weight[i],
// where inv_rms(t) = rsqrt(mean(input[t, :]^2) + eps).
// One threadgroup per token (grid: M threadgroups); the strides below
// assume 256 threads per threadgroup, i.e. 8 simdgroups of 32 lanes.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    if (tgid >= num_tokens) return;

    // Row of this threadgroup's token in the input and output batches.
    device const float* row_in = input + tgid * n;
    device float* row_out = output + tgid * n;

    // Per-thread partial sum of squares over a 256-strided slice of the row.
    float partial = 0.0f;
    for (uint idx = tid; idx < n; idx += 256) {
        float x = row_in[idx];
        partial += x * x;
    }

    // Reduce within each simdgroup, then park one value per simdgroup.
    partial = simd_sum(partial);

    threadgroup float scratch[8];
    const uint group = tid / 32;
    const uint lane = tid % 32;
    if (lane == 0) {
        scratch[group] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the 8 partials and publishes the inverse RMS in slot 0.
    if (tid == 0) {
        float acc = 0.0f;
        for (uint g = 0; g < 8; g++) {
            acc += scratch[g];
        }
        scratch[0] = fast::rsqrt(acc / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float scale = scratch[0];

    // Scale the row and apply the per-channel weight.
    for (uint idx = tid; idx < n; idx += 256) {
        row_out[idx] = row_in[idx] * scale * weight[idx];
    }
}
970
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place Rotary Position Embedding over a batch of tokens, each with its
// own position. data is [M, num_heads * head_dim]; positions is [M].
// One thread per (token, head, pair); a pair is the adjacent elements
// (2*i, 2*i + 1) inside a head, rotated by angle = position * theta^(-2i/head_dim).
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint half_dim = head_dim / 2;
    const uint pairs_per_token = num_heads * half_dim;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread id into (token, head, pair).
    const uint token = id / pairs_per_token;
    const uint within_token = id % pairs_per_token;
    const uint head = within_token / half_dim;
    const uint pair = within_token % half_dim;

    // Index of the even element of this pair.
    const uint even_idx = token * (num_heads * head_dim) + head * head_dim + 2 * pair;
    const uint odd_idx = even_idx + 1;

    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle = float(positions[token]) * inv_freq;
    const float cs = cos(angle);
    const float sn = sin(angle);

    // 2-D rotation of (even, odd) by `angle`.
    const float even = data[even_idx];
    const float odd = data[odd_idx];
    data[even_idx] = even * cs - odd * sn;
    data[odd_idx]  = even * sn + odd * cs;
}
1005
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-multiply over a batch. gate_up is [M, 2*n]: for each token the
// first n values are the gate and the next n values are the up projection.
// One thread per output element:
//   output[t*n + i] = silu(gate_up[t*2*n + i]) * gate_up[t*2*n + n + i]
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint token = id / n;
    const uint col = id % n;

    // Start of this token's [gate | up] row.
    const uint row_start = token * 2 * n;
    const float gate = gate_up[row_start + col];
    const float up = gate_up[row_start + n + col];

    // silu(g) = g * sigmoid(g) = g / (1 + exp(-g))
    const float activated = gate / (1.0f + fast::exp(-gate));
    output[token * n + col] = activated * up;
}
1025
// ── add_inplace_batch ──────────────────────────────────────────────────
// Residual add over a flattened batch: a[i] = a[i] + b[i] for i in [0, total).
// One thread per element; threads past `total` do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n], updated in place
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1037
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous [M, dim] buffer.
// tokens[t] selects the table row copied to output row t.
// One thread per output element; token IDs are not range-checked here.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    // Split the flat element index into (batch slot, column).
    const uint slot = id / dim;
    const uint col = id % dim;

    const uint table_row = tokens[slot];
    output[id] = embed[table_row * dim + col];
}
1055
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: outputs[t, r] = Σ_c matrix[r, c] * inputs[t, c]
// for M input vectors against one weight matrix.
// Grid: ceil(rows/ROWS_PER_TG) * M threadgroups; each threadgroup handles one
// (token, row_group) pair and each simdgroup accumulates 8 consecutive rows.
// The strides assume 256 threads per threadgroup, and the input vector is
// staged in a VEC_TILE_SIZE-entry threadgroup tile (cols is not checked
// against that size here).
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decode the flat threadgroup id into (token, row group within token).
    const uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    const uint token = tgid / groups_per_token;
    const uint group = tgid % groups_per_token;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* x = inputs + token * cols;
    for (uint c = tid; c < cols; c += 256) {
        vec_tile[c] = x[c];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint row_base = group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Rows accumulated per simdgroup by this kernel body.
    const uint kRows = 8;

    uint row_off[kRows];   // element offset of each row in `matrix`
    bool live[kRows];      // whether row_base + r is a real row
    float4 acc4[kRows];    // vectorized partial sums
    for (uint r = 0; r < kRows; r++) {
        row_off[r] = (row_base + r) * cols;
        live[r] = (row_base + r) < rows;
        acc4[r] = float4(0.0f);
    }

    // Vectorized body: the 32 lanes each take one float4, so a simdgroup
    // consumes 128 columns per step; only whole 128-column steps run here.
    const uint vec_cols = cols & ~127u;
    for (uint c = simd_lane * 4; c < vec_cols; c += 128) {
        const float4 xv = *reinterpret_cast<threadgroup const float4*>(vec_tile + c);
        for (uint r = 0; r < kRows; r++) {
            if (live[r]) {
                acc4[r] += *reinterpret_cast<device const float4*>(matrix + row_off[r] + c) * xv;
            }
        }
    }

    // Collapse each float4 accumulator to a scalar.
    float acc[kRows];
    for (uint r = 0; r < kRows; r++) {
        acc[r] = acc4[r].x + acc4[r].y + acc4[r].z + acc4[r].w;
    }

    // Scalar tail for the columns past the last whole 128-column step.
    for (uint c = vec_cols + simd_lane; c < cols; c += 32) {
        const float xs = vec_tile[c];
        for (uint r = 0; r < kRows; r++) {
            if (live[r]) {
                acc[r] += matrix[row_off[r] + c] * xs;
            }
        }
    }

    // Combine the 32 per-lane partials of every row.
    for (uint r = 0; r < kRows; r++) {
        acc[r] = simd_sum(acc[r]);
    }

    // Lane 0 writes the finished rows for this token.
    device float* y = outputs + token * rows;
    if (simd_lane == 0) {
        for (uint r = 0; r < kRows; r++) {
            if (live[r]) {
                y[row_base + r] = acc[r];
            }
        }
    }
}
1157
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups; each threadgroup handles
// one (token, row_group) pair and each simdgroup accumulates 4 rows.
//
// Weight layout as read by this kernel: each row is cols/32 blocks of
// 34 bytes — a 2-byte half scale followed by 32 signed 8-bit values.
// Columns beyond the last whole 32-element block are ignored.
// The strides assume 256 threads per threadgroup, and the input vector is
// staged in a VEC_TILE_SIZE-entry threadgroup tile (cols is not checked
// against that size here).
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decode the flat threadgroup id into (token, row group within token).
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4, the trailing simdgroup would point
    // r1..r3 past the end of the weight buffer and the block loop would read
    // out of bounds. Clamp those row indices to the last valid row instead:
    // the loop stays branch-free, every read is in bounds, and the duplicate
    // results are dropped by the guarded stores at the end.
    // (rows >= 1 here because row_base < rows.)
    uint last_row = rows - 1;
    uint row1 = min(row_base + 1, last_row);
    uint row2 = min(row_base + 2, last_row);
    uint row3 = min(row_base + 3, last_row);

    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + row1 * row_bytes;
    device const uchar* r2 = matrix + row2 * row_bytes;
    device const uchar* r3 = matrix + row3 * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Each lane takes every 32nd block of the row.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of the block within a weight row
        uint vb = blk * 32;  // element offset of the block within the vector

        // Per-block scale (first 2 bytes of the block).
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 vector elements matching this block, as 8 float4s.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 bytes) into 8 signed bytes and widen
        // them to two float4s: OUT_LO gets bytes 0..3, OUT_HI gets bytes 4..7.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        // Unscaled block dot products, one group of four unpacks per row.
        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply the per-block scale once per row.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the 32 per-lane partials of every row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes the rows that actually exist; clamped duplicates are dropped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1273
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers 32 rows and TOKENS_PER_TG_Q8 consecutive tokens, so
// the Q8_0 weight blocks are fetched once from device memory and reused for
// every token in the tile (1/TOKENS_PER_TG_Q8 of the weight bandwidth of the
// per-token dispatch).
//
// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
// with simd_sum.  Token vectors are read directly from device memory inside
// the block loop (not cached in shared memory) so intermediate_size up to
// 8192 fits without spilling threadgroup memory.
1286constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1287
1288kernel void matmul_q8_gemm_batch(
1289    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1290    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1291    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1292    constant uint& num_tokens    [[buffer(3)]],  // M
1293    constant uint& rows          [[buffer(4)]],
1294    constant uint& cols          [[buffer(5)]],
1295    uint2 tgid [[threadgroup_position_in_grid]],
1296    uint tid [[thread_index_in_threadgroup]],
1297    uint simd_lane [[thread_index_in_simdgroup]],
1298    uint simd_id [[simdgroup_index_in_threadgroup]])
1299{
1300    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1301    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1302    if (row_base >= rows || tok_base >= num_tokens) return;
1303
1304    // How many tokens in this tile are valid?
1305    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1306
1307    uint blocks_per_row = cols / 32;
1308    uint row_bytes = blocks_per_row * 34;
1309
1310    device const uchar* r0 = matrix + row_base * row_bytes;
1311    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1312    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1313    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1314
1315    // Accumulators: 4 tokens × 4 rows per simdgroup.
1316    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1317    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1318    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1319    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1320
1321    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1322        uint bb = blk * 34;
1323        uint vb = blk * 32;
1324
1325        // ── Load weight data ONCE per block (reused across all tokens) ──
1326        float sc0 = float(*(device const half*)(r0 + bb));
1327        float sc1 = float(*(device const half*)(r1 + bb));
1328        float sc2 = float(*(device const half*)(r2 + bb));
1329        float sc3 = float(*(device const half*)(r3 + bb));
1330
1331        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1332        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1333        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1334        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1335
1336        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1337            short4 _s = short4(SHORT4); \
1338            char2 _a = as_type<char2>(_s.x); \
1339            char2 _b = as_type<char2>(_s.y); \
1340            char2 _c = as_type<char2>(_s.z); \
1341            char2 _d = as_type<char2>(_s.w); \
1342            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1343            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1344        }
1345
1346        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1347        // registers for the duration of the block and are dotted against
1348        // every token's vector tile.
1349        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1350        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1351        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1352        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1353
1354        Q8_UNPACK8(d0[0], w0_0, w0_1);
1355        Q8_UNPACK8(d0[1], w0_2, w0_3);
1356        Q8_UNPACK8(d0[2], w0_4, w0_5);
1357        Q8_UNPACK8(d0[3], w0_6, w0_7);
1358
1359        Q8_UNPACK8(d1[0], w1_0, w1_1);
1360        Q8_UNPACK8(d1[1], w1_2, w1_3);
1361        Q8_UNPACK8(d1[2], w1_4, w1_5);
1362        Q8_UNPACK8(d1[3], w1_6, w1_7);
1363
1364        Q8_UNPACK8(d2[0], w2_0, w2_1);
1365        Q8_UNPACK8(d2[1], w2_2, w2_3);
1366        Q8_UNPACK8(d2[2], w2_4, w2_5);
1367        Q8_UNPACK8(d2[3], w2_6, w2_7);
1368
1369        Q8_UNPACK8(d3[0], w3_0, w3_1);
1370        Q8_UNPACK8(d3[1], w3_2, w3_3);
1371        Q8_UNPACK8(d3[2], w3_4, w3_5);
1372        Q8_UNPACK8(d3[3], w3_6, w3_7);
1373
1374        #undef Q8_UNPACK8
1375
1376        // ── For each token, read vector and accumulate against shared weights ──
1377        // Token 0 (always valid: tok_count >= 1).
1378        {
1379            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1380            float4 v0 = *(device const float4*)(a0);
1381            float4 v1 = *(device const float4*)(a0 + 4);
1382            float4 v2 = *(device const float4*)(a0 + 8);
1383            float4 v3 = *(device const float4*)(a0 + 12);
1384            float4 v4 = *(device const float4*)(a0 + 16);
1385            float4 v5 = *(device const float4*)(a0 + 20);
1386            float4 v6 = *(device const float4*)(a0 + 24);
1387            float4 v7 = *(device const float4*)(a0 + 28);
1388            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1389                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1390            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1391                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1392            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1393                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1394            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1395                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1396            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1397        }
1398        // Token 1
1399        if (tok_count > 1) {
1400            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1401            float4 v0 = *(device const float4*)(a1);
1402            float4 v1 = *(device const float4*)(a1 + 4);
1403            float4 v2 = *(device const float4*)(a1 + 8);
1404            float4 v3 = *(device const float4*)(a1 + 12);
1405            float4 v4 = *(device const float4*)(a1 + 16);
1406            float4 v5 = *(device const float4*)(a1 + 20);
1407            float4 v6 = *(device const float4*)(a1 + 24);
1408            float4 v7 = *(device const float4*)(a1 + 28);
1409            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1410                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1411            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1412                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1413            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1414                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1415            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1416                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1417            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1418        }
1419        // Token 2
1420        if (tok_count > 2) {
1421            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1422            float4 v0 = *(device const float4*)(a2);
1423            float4 v1 = *(device const float4*)(a2 + 4);
1424            float4 v2 = *(device const float4*)(a2 + 8);
1425            float4 v3 = *(device const float4*)(a2 + 12);
1426            float4 v4 = *(device const float4*)(a2 + 16);
1427            float4 v5 = *(device const float4*)(a2 + 20);
1428            float4 v6 = *(device const float4*)(a2 + 24);
1429            float4 v7 = *(device const float4*)(a2 + 28);
1430            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1431                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1432            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1433                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1434            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1435                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1436            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1437                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1438            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1439        }
1440        // Token 3
1441        if (tok_count > 3) {
1442            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1443            float4 v0 = *(device const float4*)(a3);
1444            float4 v1 = *(device const float4*)(a3 + 4);
1445            float4 v2 = *(device const float4*)(a3 + 8);
1446            float4 v3 = *(device const float4*)(a3 + 12);
1447            float4 v4 = *(device const float4*)(a3 + 16);
1448            float4 v5 = *(device const float4*)(a3 + 20);
1449            float4 v6 = *(device const float4*)(a3 + 24);
1450            float4 v7 = *(device const float4*)(a3 + 28);
1451            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1452                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1453            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1454                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1455            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1456                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1457            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1458                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1459            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1460        }
1461    }
1462
1463    // simdgroup reduction
1464    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1465    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1466    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1467    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1468
1469    if (simd_lane == 0) {
1470        device float* o0 = outputs + (tok_base + 0) * rows;
1471        if (row_base     < rows) o0[row_base]     = s00;
1472        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1473        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1474        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1475
1476        if (tok_count > 1) {
1477            device float* o1 = outputs + (tok_base + 1) * rows;
1478            if (row_base     < rows) o1[row_base]     = s10;
1479            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1480            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1481            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1482        }
1483        if (tok_count > 2) {
1484            device float* o2 = outputs + (tok_base + 2) * rows;
1485            if (row_base     < rows) o2[row_base]     = s20;
1486            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1487            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1488            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1489        }
1490        if (tok_count > 3) {
1491            device float* o3 = outputs + (tok_base + 3) * rows;
1492            if (row_base     < rows) o3[row_base]     = s30;
1493            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1494            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1495            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1496        }
1497    }
1498}
1499
1500// ── matmul_q8_mma ──────────────────────────────────────────────────────
1501// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
1502// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
1503// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
1504// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
1505//
1506// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
1507// 4 simdgroups per TG, each computing a single 8×8 output sub-tile via one
1508// simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes are cooperatively
1509// dequantized into threadgroup memory once per block and reused by all
1510// simdgroups in the tile.
1511//
1512// Assumptions (verified in the dispatch helper, falls back otherwise):
1513//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
1514//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
1515//   * num_tokens may be any value; partial row at the tile boundary is handled
1516//     via a scratch copy path.
1517constant constexpr uint MMA_TOK_TILE = 16;
1518constant constexpr uint MMA_ROW_TILE = 16;
1519
// 16×16-tile Q8_0 GEMM entry point: outputs[tok, row] = Σ_k inputs[tok, k] * dequant(matrix[row, k]).
//   matrix     : Q8_0 weights; each row is cols/32 blocks of 34 bytes
//                (one half scale followed by 32 int8 quants).
//   inputs     : f32 activations, indexed as inputs[tok * cols + k].
//   outputs    : f32 results, indexed as outputs[tok * rows + row].
//   num_tokens : number of token rows; the last token tile may be partial.
//   rows, cols : weight dimensions.  The indexing assumes cols % 32 == 0 and
//                rows % MMA_ROW_TILE == 0 — no per-row bounds checks are made.
// The cooperative tile fills assume 128 threads (4 simdgroups of 32) per threadgroup.
kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // The condition depends only on tgid, so the whole threadgroup leaves
    // together and nobody is stranded at the barriers below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // token offset of this simdgroup's sub-tile
    uint sg_row_base = (simd_id % 2) * 8;  // weight-row offset of this simdgroup's sub-tile

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread fills 4 consecutive tile elements starting at tid * 4.
    // 32 is a multiple of 4, so those elements never straddle a tile row:
    // the tile row, starting column, weight-row pointer and token validity
    // are the same for all 4 and for every block, and are computed once.
    uint fill_row = (tid * 4) / 32;
    uint fill_k   = (tid * 4) % 32;
    uint fill_off = fill_row * 32 + fill_k;
    device const uchar* w_row = matrix + (row_base + fill_row) * row_bytes;
    uint fill_tok  = tok_base + fill_row;
    bool tok_valid = fill_tok < num_tokens;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        {
            device const uchar* blk_ptr = w_row + blk * 34;
            float scale = float(*(device const half*)blk_ptr);
            device const int8_t* quants = (device const int8_t*)(blk_ptr + 2) + fill_k;
            for (uint ii = 0; ii < 4; ii++) {
                w_tile[fill_off + ii] = float(int(quants[ii])) * scale;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Tokens past num_tokens are zero-filled so the MMA stays well-defined.
        {
            uint in_base = fill_tok * cols + blk * 32 + fill_k;
            for (uint ii = 0; ii < 4; ii++) {
                t_tile[fill_off + ii] = tok_valid ? inputs[in_base + ii] : 0.0f;
            }
        }

        // Both tiles must be fully written before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A, t_tile + sg_tok_base * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B, w_tile + sg_row_base * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Keep the next iteration's tile writes from racing with these loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], so the row stride is `rows`.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    if (out_tok + 8 <= num_tokens) {
        // Fast path: the entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage the sub-tile in per-simdgroup scratch and
        // let lane 0 copy only the valid token rows.
        threadgroup float scratch[4 * 64];
        threadgroup float* sg_scratch = scratch + simd_id * 64;
        simdgroup_store(C, sg_scratch, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = sg_scratch + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst[j] = src[j];
                }
            }
        }
    }
}
1628
1629// ── matmul_q8_mma32 ────────────────────────────────────────────────────
1630// Larger-tile variant of matmul_q8_mma for long-context prefill.
1631//
1632// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
1633// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
1634// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
1635//
1636//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
1637//     output sub-tiles (tok, row):
1638//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
1639//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
1640//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration — better FLOP/load ratio than the 16×16 single-accumulator
// kernel — and, because each 32×32 tile covers 4× the output area, needs
// only a quarter as many threadgroups as the 16×16 tile.
1644//
1645// Assumptions (verified in dispatch helper, fallback otherwise):
1646//   * cols % 32 == 0
1647//   * rows % 32 == 0
1648constant constexpr uint MMA32_TOK_TILE = 32;
1649constant constexpr uint MMA32_ROW_TILE = 32;
1650
// 32×32-tile Q8_0 GEMM entry point: outputs[tok, row] = Σ_k inputs[tok, k] * dequant(matrix[row, k]).
//   matrix     : Q8_0 weights; each row is cols/32 blocks of 34 bytes
//                (one half scale followed by 32 int8 quants).
//   inputs     : f32 activations, indexed as inputs[tok * cols + k].
//   outputs    : f32 results, indexed as outputs[tok * rows + row].
//   num_tokens : number of token rows; the last token tile may be partial.
//   rows, cols : weight dimensions.  The indexing assumes cols % 32 == 0 and
//                rows % MMA32_ROW_TILE == 0 — no per-row bounds checks are made.
// The cooperative tile fills assume 256 threads (8 simdgroups of 32) per threadgroup.
kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so it is safe ahead of
    // the barriers below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    // simd_id = 2 * (token strip 0..3) + (row half 0..1); each simdgroup owns
    // two 8×8 accumulators stacked along the weight-row axis.
    uint sg_tok_base   = (simd_id / 2) * 8;
    uint sg_row_base_a = (simd_id % 2) * 16;
    uint sg_row_base_b = sg_row_base_a + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // 32*32 tile elements / 256 threads = 4 consecutive elements per thread,
    // starting at tid * 4.  32 is a multiple of 4, so the 4 elements share one
    // tile row; everything that depends only on that row is computed once.
    uint fill_row = (tid * 4) / 32;
    uint fill_k   = (tid * 4) % 32;
    uint fill_off = fill_row * 32 + fill_k;
    device const uchar* w_row = matrix + (row_base + fill_row) * row_bytes;
    uint fill_tok  = tok_base + fill_row;
    bool tok_valid = fill_tok < num_tokens;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: int8 quant × per-block half scale.
        {
            device const uchar* blk_ptr = w_row + blk * 34;
            float scale = float(*(device const half*)blk_ptr);
            device const int8_t* quants = (device const int8_t*)(blk_ptr + 2) + fill_k;
            for (uint ii = 0; ii < 4; ii++) {
                w_tile[fill_off + ii] = float(int(quants[ii])) * scale;
            }
        }

        // Cooperative token tile load; tokens past num_tokens are zero-filled.
        {
            uint in_base = fill_tok * cols + blk * 32 + fill_k;
            for (uint ii = 0; ii < 4; ii++) {
                t_tile[fill_off + ii] = tok_valid ? inputs[in_base + ii] : 0.0f;
            }
        }

        // Tiles must be complete before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        // The weight fragments are loaded transposed so that B is K×R.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,   t_tile + sg_tok_base   * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_a, w_tile + sg_row_base_a * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_b, w_tile + sg_row_base_b * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration's tile writes from racing with these loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  No bounds check is made along the row
    // dimension (rows is assumed MMA32_ROW_TILE-aligned); only the last token
    // tile may be partial.
    uint out_tok   = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token strip: stage both accumulators in per-simdgroup
        // scratch, then lane 0 copies only the valid token rows.
        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
        threadgroup float* sg_scratch = scratch + simd_id * 128;
        simdgroup_store(C_a, sg_scratch, 8);
        simdgroup_store(C_b, sg_scratch + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = sg_scratch + m * 8;
                threadgroup const float* src_b = sg_scratch + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1769
1770// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
1771// FP16 threadgroup-tile variant of matmul_q8_mma32.
1772//
1773// Stores dequantized weights and token inputs as `half` in threadgroup
1774// memory — halving the shared-memory footprint (4 KB total vs 8 KB) and
1775// doubling concurrent-threadgroup occupancy per GPU core on Apple Silicon.
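// Each Q8_0 weight is an int8 value times the block's f16 scale (the scale
// is read as `half` from the block header), so an f16 intermediate
// representation keeps the quantized dynamic range, at the cost of rounding
// the dequantized product once to f16.  Token activations are likewise
// narrowed to f16 in the tile; the sums stay numerically safe because the
// subsequent `simdgroup_multiply_accumulate` keeps the accumulator in `float`.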
1780//
1781// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
1782// accumulators each.  Primary win vs mma32 is occupancy at moderate
1783// prefill lengths where the GPU is wave-starved.
// 32×32-tile Q8_0 GEMM with f16 threadgroup tiles and f32 accumulators.
//   matrix     : Q8_0 weights; each row is cols/32 blocks of 34 bytes
//                (one half scale followed by 32 int8 quants).
//   inputs     : f32 activations, indexed as inputs[tok * cols + k]; narrowed
//                to half when staged into the tile.
//   outputs    : f32 results, indexed as outputs[tok * rows + row].
//   num_tokens : number of token rows; the last token tile may be partial.
//   rows, cols : weight dimensions.  The indexing assumes cols % 32 == 0 and
//                rows % MMA32_ROW_TILE == 0 — no per-row bounds checks are made.
// The cooperative tile fills assume 256 threads (8 simdgroups of 32) per threadgroup.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so it is safe ahead of
    // the barriers below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // simd_id = 2 * (token strip 0..3) + (row half 0..1); each simdgroup owns
    // two 8×8 accumulators stacked along the weight-row axis.
    uint sg_tok_base   = (simd_id / 2) * 8;
    uint sg_row_base_a = (simd_id % 2) * 16;
    uint sg_row_base_b = sg_row_base_a + 8;

    // Accumulators stay in float even though the tiles are half.
    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread fills 4 consecutive tile elements starting at tid * 4.
    // 32 is a multiple of 4, so the 4 elements share one tile row; everything
    // that depends only on that row is computed once, outside the block loop.
    uint fill_row = (tid * 4) / 32;
    uint fill_k   = (tid * 4) % 32;
    uint fill_off = fill_row * 32 + fill_k;
    device const uchar* w_row = matrix + (row_base + fill_row) * row_bytes;
    uint fill_tok  = tok_base + fill_row;
    bool tok_valid = fill_tok < num_tokens;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16: the product is formed in
        // float and narrowed once.
        {
            device const uchar* blk_ptr = w_row + blk * 34;
            float scale = float(*(device const half*)blk_ptr);
            device const int8_t* quants = (device const int8_t*)(blk_ptr + 2) + fill_k;
            for (uint ii = 0; ii < 4; ii++) {
                w_tile[fill_off + ii] = half(float(int(quants[ii])) * scale);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); tokens past
        // num_tokens are zero-filled.
        {
            uint in_base = fill_tok * cols + blk * 32 + fill_k;
            for (uint ii = 0; ii < 4; ii++) {
                t_tile[fill_off + ii] = tok_valid ? half(inputs[in_base + ii]) : half(0);
            }
        }

        // Tiles must be complete before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8; A is reused for both row accumulators and the
        // weight fragments are loaded transposed so that B is K×R.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,   t_tile + sg_tok_base   * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_a, w_tile + sg_row_base_a * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_b, w_tile + sg_row_base_b * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration's tile writes from racing with these loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both accumulators; only the token dimension is bounds-checked.
    uint out_tok   = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token strip: stage both accumulators in per-simdgroup
        // scratch, then lane 0 copies only the valid token rows.
        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
        threadgroup float* sg_scratch = scratch + simd_id * 128;
        simdgroup_store(C_a, sg_scratch, 8);
        simdgroup_store(C_b, sg_scratch + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = sg_scratch + m * 8;
                threadgroup const float* src_b = sg_scratch + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1898
1899// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
1900// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
1901//
1902// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
1903// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
1904//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
1905//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
1906// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
1907//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's.  That's 4 MMAs per 4 fragment loads versus 2 MMAs per 3 loads
// in the 2-accumulator kernel (1.5× the FLOP-per-simdgroup-load), and it
// halves the thread count per threadgroup (128 threads), which often
// improves occupancy on Apple GPUs where the concurrent-thread budget is a
// tighter limit than shared-memory size.
// 32×32-tile Q8_0 GEMM with f16 tiles, 4 simdgroups and a 2×2 grid of f32
// accumulators per simdgroup.
//   matrix     : Q8_0 weights; each row is cols/32 blocks of 34 bytes
//                (one half scale followed by 32 int8 quants).
//   inputs     : f32 activations, indexed as inputs[tok * cols + k]; narrowed
//                to half when staged into the tile.
//   outputs    : f32 results, indexed as outputs[tok * rows + row].
//   num_tokens : number of token rows; the last token tile may be partial.
//   rows, cols : weight dimensions.  The indexing assumes cols % 32 == 0 and
//                rows % MMA32_ROW_TILE == 0 — no per-row bounds checks are made.
// The cooperative tile fills assume 128 threads (4 simdgroups of 32) per threadgroup.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so it is safe ahead of
    // the barriers below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_base = (simd_id / 2) * 16;
    uint sg_row_base = (simd_id % 2) * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // 128 threads × 8 halves = 1024 = 32*32: each thread fills 8 consecutive
    // tile elements starting at tid * 8.  32 is a multiple of 8, so the 8
    // elements share one tile row; everything that depends only on that row
    // is computed once, outside the block loop.
    uint fill_row = (tid * 8) / 32;
    uint fill_k   = (tid * 8) % 32;
    uint fill_off = fill_row * 32 + fill_k;
    device const uchar* w_row = matrix + (row_base + fill_row) * row_bytes;
    uint fill_tok  = tok_base + fill_row;
    bool tok_valid = fill_tok < num_tokens;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant: product formed in float, narrowed once to half.
        {
            device const uchar* blk_ptr = w_row + blk * 34;
            float scale = float(*(device const half*)blk_ptr);
            device const int8_t* quants = (device const int8_t*)(blk_ptr + 2) + fill_k;
            for (uint ii = 0; ii < 8; ii++) {
                w_tile[fill_off + ii] = half(float(int(quants[ii])) * scale);
            }
        }

        // Cooperative token tile load (f32 → f16); tokens past num_tokens are
        // zero-filled.
        {
            uint in_base = fill_tok * cols + blk * 32 + fill_k;
            for (uint ii = 0; ii < 8; ii++) {
                t_tile[fill_off + ii] = tok_valid ? half(inputs[in_base + ii]) : half(0);
            }
        }

        // Tiles must be complete before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.  Weight
        // fragments are loaded transposed so that B is K×R.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top, t_tile + (sg_tok_base + 0) * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(A_bot, t_tile + (sg_tok_base + 8) * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_lo,  w_tile + (sg_row_base + 0) * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_hi,  w_tile + (sg_row_base + 8) * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Keep the next iteration's tile writes from racing with these loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  The fast path requires all 16 token rows of this
    // simdgroup's quadrant to be valid.
    uint out_tok_top = tok_base + sg_tok_base;
    uint out_tok_bot = out_tok_top + 8;
    uint out_row_lo  = row_base + sg_row_base;
    uint out_row_hi  = out_row_lo + 8;
    if (out_tok_bot + 8 <= num_tokens) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch: all four
        // accumulators are staged, then lane 0 copies each token row that is
        // still below num_tokens (possibly none for this quadrant).
        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
        threadgroup float* sg_scratch = scratch + simd_id * 256;
        simdgroup_store(C_00, sg_scratch +   0, 8);
        simdgroup_store(C_01, sg_scratch +  64, 8);
        simdgroup_store(C_10, sg_scratch + 128, 8);
        simdgroup_store(C_11, sg_scratch + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = sg_scratch +   0 + m * 8;
                    threadgroup const float* src1 = sg_scratch +  64 + m * 8;
                    for (uint j = 0; j < 8; j++) {
                        dst0[j] = src0[j];
                        dst1[j] = src1[j];
                    }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = sg_scratch + 128 + m * 8;
                    threadgroup const float* src3 = sg_scratch + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) {
                        dst2[j] = src2[j];
                        dst3[j] = src3[j];
                    }
                }
            }
        }
    }
}
2053
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)).  Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Data layout as indexed by the code below:
//   matrix  : per row, cols/32 blocks of 34 bytes (half scale + 32 int8 quants)
//   inputs  : f32, indexed as inputs[tok * cols + k]
//   outputs : f32, indexed as outputs[tok * rows + row]
//
// Work split: simdgroup `simd_id` owns the 16×16 output sub-tile at
// (tokens (simd_id / 2) * 16, rows (simd_id % 2) * 16) inside the threadgroup
// tile, held as four 8×8 accumulators.  Every thread stages 8 elements of
// each shared tile per K block (tid * 8 + 0..7).
//
// NOTE(review): tokens are bounds-checked but rows are not (neither when
// staging weights nor when storing) — presumably the dispatcher only selects
// this kernel when `rows` is a multiple of the row tile; confirm at the call
// site.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Dequantized weights and half-converted activations for one 32-wide K block.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Staging slot of this thread.  tid * 8 is a multiple of 8 and 8 divides
    // 32, so the 8 elements tid * 8 + 0..7 all lie in the same tile row: the
    // row index, the block scale and the token validity are invariant across
    // the 8-element loop and are computed once here / once per block.
    uint stage_idx = tid * 8;
    uint stage_row = stage_idx / 32;   // tile row (weight row / token) staged by this thread
    uint stage_k   = stage_idx % 32;   // first K offset inside the block
    uint stage_tok = tok_base + stage_row;
    bool stage_tok_valid = stage_tok < num_tokens;
    device const uchar* w_row = matrix + (row_base + stage_row) * row_bytes;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Block header: half scale, followed by 32 signed 8-bit quants.
        device const uchar* rp = w_row + blk * 34;
        float sc = float(*(device const half*)rp);
        device const int8_t* quants = (device const int8_t*)(rp + 2 + stage_k);
        uint in_off = stage_tok * cols + blk * 32 + stage_k;  // only dereferenced when the token is valid

        for (uint ii = 0; ii < 8; ii++) {
            w_tile[stage_idx + ii] = half(float(int(quants[ii])) * sc);
            // Tokens past the end of the batch contribute zeros.
            t_tile[stage_idx + ii] = stage_tok_valid
                ? half(inputs[in_off + ii])
                : half(0);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 32-wide K block = 4 MMA steps of depth 8.  B tiles are loaded
        // transposed so that C[token, row] accumulates tokens × weightsᵀ.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Tiles are overwritten by the next block; wait for all readers.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;

    // Per-simdgroup scratch: 4 accumulators × 64 halfs, row-major 8×8 each,
    // in the order C_00, C_01, C_10, C_11.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // Write-back is spread over all 32 lanes (8 elements per lane) instead of
    // having lane 0 copy all 256 values alone.  Element e of the scratch slab
    // decodes as: accumulator e / 64, tile row (e % 64) / 8, tile column e % 8.
    // Accumulator bit 1 selects the bottom token half, bit 0 the high row half.
    uint lane = tid % 32;
    for (uint e = lane; e < 256; e += 32) {
        uint acc_id = e / 64;
        uint m      = (e % 64) / 8;
        uint j      = e % 8;
        uint tok = (((acc_id & 2u) != 0u) ? out_tok_bot : out_tok_top) + m;
        if (tok < num_tokens) {
            uint row = (((acc_id & 1u) != 0u) ? out_row_hi : out_row_lo) + j;
            outputs[tok * rows + row] = float(scratch[sg_base + e]);
        }
    }
}
2191
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add of a bias vector onto a row-major [num_tokens, rows]
// buffer: every token row receives the same `bias`.
// Used for Qwen2 QKV bias after the fused qkv matmul.
//     out[token, i] += bias[i]    for i in 0..rows, token in 0..num_tokens
// One thread per output element; threads whose id falls past the last
// element do nothing.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows]
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Flat index id = token * rows + i, so the bias column is id % rows.
    if (id < num_tokens * rows) {
        out[id] = out[id] + bias[id % rows];
    }
}
2208
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup handles one token and one group of Q4_ROWS_PER_TG rows.
// The token's input vector is staged in threadgroup memory once, then every
// simdgroup computes 4 consecutive output rows, its 32 lanes striding over
// the 32-element quantization blocks and combining with simd_sum.
//
// Q4_0 block as decoded below (18 bytes): half scale, then 16 bytes where
// byte i holds element i in its low nibble and element i + 16 in its high
// nibble, both stored with a +8 offset.
//
// NOTE(review): the staging loop uses a stride of 256 and writes vec_tile[i]
// for every i < cols, so it presumably relies on 256 threads per threadgroup
// and cols <= VEC_TILE_SIZE — confirm against the dispatch code.

// Dot product of the 8 values packed in 4 Q4_0 bytes with the matching
// activations: low nibbles pair with `lo`, high nibbles with `hi`.
static inline float q4_batch_dot8(uchar4 b, float4 lo, float4 hi)
{
    return float(int(b.x & 0xF) - 8) * lo.x + float(int(b.x >> 4) - 8) * hi.x
         + float(int(b.y & 0xF) - 8) * lo.y + float(int(b.y >> 4) - 8) * hi.y
         + float(int(b.z & 0xF) - 8) * lo.z + float(int(b.z >> 4) - 8) * hi.z
         + float(int(b.w & 0xF) - 8) * lo.w + float(int(b.w >> 4) - 8) * hi.w;
}

kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Flat threadgroup id → (token, row group).
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Safe to exit here: no barrier follows.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // Rows row_base+1..+3 may lie past the end of the matrix when `rows` is
    // not a multiple of 4.  Their results are discarded by the guarded stores
    // at the bottom, but the loads must still stay in bounds, so the row
    // index used for reading is clamped to the last valid row
    // (rows >= 1 here because row_base < rows).
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;   // byte offset of this block within a row
        uint vb = blk * 32;   // element offset of this block within the input

        // Per-row block scales.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // 16 packed bytes per row, read as 4 × packed_uchar4 (byte-aligned).
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // The block's 32 activations: v0..v3 = elements 0..15 (low nibbles),
        // v4..v7 = elements 16..31 (high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;

        bd0 += q4_batch_dot8(uchar4(p0[0]), v0, v4);
        bd0 += q4_batch_dot8(uchar4(p0[1]), v1, v5);
        bd0 += q4_batch_dot8(uchar4(p0[2]), v2, v6);
        bd0 += q4_batch_dot8(uchar4(p0[3]), v3, v7);

        bd1 += q4_batch_dot8(uchar4(p1[0]), v0, v4);
        bd1 += q4_batch_dot8(uchar4(p1[1]), v1, v5);
        bd1 += q4_batch_dot8(uchar4(p1[2]), v2, v6);
        bd1 += q4_batch_dot8(uchar4(p1[3]), v3, v7);

        bd2 += q4_batch_dot8(uchar4(p2[0]), v0, v4);
        bd2 += q4_batch_dot8(uchar4(p2[1]), v1, v5);
        bd2 += q4_batch_dot8(uchar4(p2[2]), v2, v6);
        bd2 += q4_batch_dot8(uchar4(p2[3]), v3, v7);

        bd3 += q4_batch_dot8(uchar4(p3[0]), v0, v4);
        bd3 += q4_batch_dot8(uchar4(p3[1]), v1, v5);
        bd3 += q4_batch_dot8(uchar4(p3[2]), v2, v6);
        bd3 += q4_batch_dot8(uchar4(p3[3]), v3, v7);

        // The scale is applied once per block, not per element.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2309
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter the K (or V) slice of every token in a batch QKV buffer into the
// KV cache, narrowing f32 → f16 on the way.
// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
// dst layout: contiguous [max_seq, kv_dim] cache.
// One thread per (token, element) pair; token t lands at cache row
// base_pos + t.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer (f32)
    device half* dst         [[buffer(1)]],  // KV cache (f16)
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // elements per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    // Surplus threads beyond the M * kv_dim elements have nothing to copy.
    if (id >= M * kv_dim) return;

    // Split the flat id into (batch row, element within the KV vector).
    uint batch_row = id / kv_dim;
    uint elem = id % kv_dim;

    device const float* from = src + batch_row * src_stride + src_offset;
    device half* to = dst + (base_pos + batch_row) * kv_dim;
    to[elem] = half(from[elem]);
}
2332
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// Three phases, separated by threadgroup barriers:
//   1. scores[s] = dot(q, k[s]) / sqrt(head_dim)   (one simdgroup per position)
//   2. softmax over scores[0..seq_len)             (cooperative max / sum)
//   3. out[d]    = sum_s scores[s] * v[s, d]
//
// The strides used below (8 simdgroups in step 1, 256 threads in steps 2-3,
// 8-entry reduction arrays) assume 256 threads per threadgroup.
// NOTE(review): scores[] holds ATTN_SCORES_SIZE entries and is indexed up to
// seq_len - 1 without a clamp, so callers presumably guarantee
// base_pos + M <= ATTN_SCORES_SIZE — confirm at the dispatch site.
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const half*  k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half*  v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    // Grouped-query attention: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores — sized to the effective max_seq_len
    // (4096 for all supported models) so long-context attention doesn't overflow.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction
    // The softmax scale is the same for every position; compute it once.
    float scale = fast::rsqrt(float(head_dim));
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a. Global max: per-thread → per-simdgroup → thread 0 merges the 8 partials.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b. Exponentiate relative to the max (numerical stability) and sum.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Store the reciprocal so normalization is a multiply; guard against 0.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];
    // 2c. Normalize.
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
    // This is much better than the scalar version where only 64 of 256 threads are active.
    //
    // The half4 loads and the float4 store are only naturally aligned when
    // every head offset is a multiple of 4 elements.  All offsets here are a
    // multiple of head_dim plus d (itself a multiple of 4), so the vector path
    // is taken only when head_dim % 4 == 0; otherwise head_dim4 is 0 and the
    // scalar loop below computes every dimension.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = ((head_dim & 3u) == 0u) ? (head_dim / 4) : 0u;
    uint seq_len4 = seq_len & ~3u;  // positions handled by the 4-way unrolled loop
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
                 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
                 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
                 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
        }
        // Tail positions when seq_len is not a multiple of 4.
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Scalar path: covers all of head_dim when it is not a multiple of 4,
    // and nothing otherwise.
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2458
// ── attention_flash_batch ─────────────────────────────────────────────
// Streaming causal attention with an online softmax.  The grid matches
// attention_batch (M × num_heads threadgroups, one per (token, head) pair),
// but the full score vector is never stored: K/V positions are consumed
// FLASH_K_TILE at a time while a running (m, l, O) triple is maintained with
// the standard flash-attention recurrence:
//
//     m_new   = max(m_old, tile_max)
//     alpha   = exp(m_old - m_new)
//     l_new   = alpha * l_old + sum(exp(S - m_new))
//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
//     O_final = O / l
//
// Because only one tile of scores is live at a time, this kernel is not
// bound by the fixed `scores[ATTN_SCORES_SIZE]` array of attention_batch
// (which is indexed by position without a clamp), and its threadgroup memory
// use is O(head_dim + FLASH_K_TILE) instead of O(seq_len).
//
// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
// The strides of 8 simdgroups / 256 threads used below assume a 256-thread
// threadgroup.
constant constexpr uint FLASH_K_TILE = 32;
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;

kernel void attention_flash_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const half*  k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half*  v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],
    constant uint& base_pos          [[buffer(5)]],
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decode the flat threadgroup id into (token, head).
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    uint kv_head = head / (num_heads / num_kv_heads);
    // Causal window: positions [0, base_pos + token_idx].
    uint seq_len = base_pos + token_idx + 1;
    device const float* q_row = q_batch + token_idx * q_stride + head * head_dim;
    device float* out_row = output_batch + token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup-shared state:
    //   q_sh       query vector of this (token, head), loaded once
    //   o_sh       un-normalized running output O
    //   scores_sh  scores (later probabilities) of the current K tile only
    //   stats      stats[0] = running max m, stats[1] = running sum l
    //   sg_scratch per-simdgroup partials; also reused to broadcast alpha / m_new
    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float scores_sh[FLASH_K_TILE];
    threadgroup float stats[2];
    threadgroup float sg_scratch[8];

    // --- Initialise: copy Q, clear O, reset (m, l) ---
    for (uint d = tid; d < head_dim; d += 256) {
        q_sh[d] = q_row[d];
        o_sh[d] = 0.0f;
    }
    if (tid == 0) {
        stats[0] = -INFINITY;
        stats[1] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = fast::rsqrt(float(head_dim));
    // K and V caches share the same row stride and per-head column offset.
    uint kv_stride = num_kv_heads * head_dim;
    uint kv_col0 = kv_head * head_dim;

    // --- Walk the K/V cache one tile at a time ---
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
        // The last tile may be partial.
        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);

        // [1] Tile scores: scores_sh[ti] = dot(q, k[kv_base + ti]) * scale.
        // Simdgroup g takes tile positions g, g + 8, ...; its 32 lanes split head_dim.
        for (uint ti = simd_id; ti < tile_n; ti += 8) {
            device const half* k_row = k_cache + (kv_base + ti) * kv_stride + kv_col0;
            float dot = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q_sh[d] * k_row[d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores_sh[ti] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Tile maximum: per-thread → per-simdgroup partials in sg_scratch.
        float thread_max = -INFINITY;
        for (uint s = tid; s < tile_n; s += 256) {
            thread_max = max(thread_max, scores_sh[s]);
        }
        thread_max = simd_max(thread_max);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = thread_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [3] Thread 0 folds the partials into the running max, derives the
        // rescale factor, applies it to the running sum and publishes both
        // values through sg_scratch[0..1].
        if (tid == 0) {
            float tile_max = sg_scratch[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
            float prev_max = stats[0];
            float next_max = max(prev_max, tile_max);
            // First tile: prev_max is -inf, so the factor is forced to 0.
            float rescale = (prev_max == -INFINITY) ? 0.0f : fast::exp(prev_max - next_max);
            stats[0] = next_max;
            stats[1] *= rescale;
            sg_scratch[0] = rescale;
            sg_scratch[1] = next_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        float alpha = sg_scratch[0];
        float m_new = sg_scratch[1];

        // [4] Bring the running output to the new max, and turn the tile's
        // scores into un-normalized probabilities exp(score - m_new).
        for (uint d = tid; d < head_dim; d += 256) {
            o_sh[d] *= alpha;
        }
        for (uint s = tid; s < tile_n; s += 256) {
            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Tile probability mass → added to the running sum l.
        float thread_sum = 0.0f;
        for (uint s = tid; s < tile_n; s += 256) {
            thread_sum += scores_sh[s];
        }
        thread_sum = simd_sum(thread_sum);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = thread_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
            stats[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] P @ V for this tile: o_sh[d] += sum_s P[s] * V[kv_base + s, d].
        for (uint d = tid; d < head_dim; d += 256) {
            device const half* v_col = v_cache + kv_base * kv_stride + kv_col0 + d;
            float acc = 0.0f;
            for (uint s = 0; s < tile_n; s++) {
                acc += scores_sh[s] * v_col[s * kv_stride];
            }
            o_sh[d] += acc;
        }
        // scores_sh and sg_scratch are rewritten by the next tile.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // --- Final normalization: O / l (0 if the sum is not positive) ---
    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        out_row[d] = o_sh[d] * inv_l;
    }
}
2623
2624// ── attention_mma_flash_batch ─────────────────────────────────────────
2625// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
2626// both Q·K^T and P·V.  Processes Q_BLOCK=8 tokens of one head per
2627// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
2628// K/V loads across 8 Q rows and using hardware matrix-multiply for the
2629// arithmetic.
2630//
2631// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
2632// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
2633// to attention_batch / attention_flash_batch otherwise.
2634//
2635// Online softmax recurrence is identical to attention_flash_batch but
2636// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
2637constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
2638constant constexpr uint FLASH_MMA_K_BLOCK = 32;
2639constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;
2640
2641kernel void attention_mma_flash_batch(
2642    device const float* q_batch      [[buffer(0)]],
2643    device const half*  k_cache      [[buffer(1)]],
2644    device const half*  v_cache      [[buffer(2)]],
2645    device float* output_batch       [[buffer(3)]],
2646    constant uint& M                 [[buffer(4)]],
2647    constant uint& base_pos          [[buffer(5)]],
2648    constant uint& num_heads         [[buffer(6)]],
2649    constant uint& num_kv_heads      [[buffer(7)]],
2650    constant uint& head_dim          [[buffer(8)]],
2651    constant uint& q_stride          [[buffer(9)]],
2652    uint2 tgid [[threadgroup_position_in_grid]],
2653    uint tid [[thread_index_in_threadgroup]],
2654    uint simd_lane [[thread_index_in_simdgroup]],
2655    uint simd_id [[simdgroup_index_in_threadgroup]])
2656{
2657    uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2658    uint head = tgid.y;
2659    if (q_block_start >= M) return;
2660    uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2661
2662    uint kv_head = head / (num_heads / num_kv_heads);
2663    // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2664    // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2665    uint seq_len = base_pos + q_block_start + q_valid;
2666    float scale = fast::rsqrt(float(head_dim));
2667
2668    uint kv_stride = num_kv_heads * head_dim;
2669    uint kv_base_off = kv_head * head_dim;
2670
2671    // ── Threadgroup memory ──
2672    // q_sh:  [Q_BLOCK, head_dim] half  — Q tile, loaded once
2673    // k_sh:  [K_BLOCK, head_dim] half  — K tile, refreshed per kv_base iter
2674    // v_sh:  [K_BLOCK, head_dim] half  — V tile, refreshed per kv_base iter
2675    // s_sh:  [Q_BLOCK, K_BLOCK] float  — raw Q·K^T scores, then scaled+masked
2676    // p_sh:  [Q_BLOCK, K_BLOCK] half   — softmax probabilities (for P·V MMA)
2677    // o_sh:  [Q_BLOCK, head_dim] float — running output accumulator
2678    // m_sh:  [Q_BLOCK] float           — running max per Q row
2679    // l_sh:  [Q_BLOCK] float           — running softmax denominator per Q row
2680    // scratch: 4*Q_BLOCK floats        — per-row reduction scratch
2681    threadgroup half  q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2682    threadgroup half  k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2683    threadgroup half  v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2684    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2685    threadgroup half  p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2686    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2687    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2688    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2689    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2690
2691    // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2692    uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2693    for (uint i = tid; i < qblock_elems; i += 128) {
2694        uint q = i / head_dim;
2695        uint d = i % head_dim;
2696        if (q < q_valid) {
2697            uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2698            q_sh[q * head_dim + d] = half(q_batch[q_off]);
2699        } else {
2700            q_sh[q * head_dim + d] = half(0);
2701        }
2702        o_sh[q * head_dim + d] = 0.0f;
2703    }
2704    if (tid < FLASH_MMA_Q_BLOCK) {
2705        m_sh[tid] = -INFINITY;
2706        l_sh[tid] = 0.0f;
2707    }
2708    threadgroup_barrier(mem_flags::mem_threadgroup);
2709
2710    // ── Stream K/V in K_BLOCK chunks ──
2711    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2712        uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2713
2714        // Load K and V tile into TG memory (as half).
2715        uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2716        for (uint i = tid; i < kv_tile_elems; i += 128) {
2717            uint k_pos = i / head_dim;
2718            uint d = i % head_dim;
2719            if (k_pos < tile_n) {
2720                uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2721                k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2722                v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2723            } else {
2724                k_sh[k_pos * head_dim + d] = half(0);
2725                v_sh[k_pos * head_dim + d] = half(0);
2726            }
2727        }
2728        threadgroup_barrier(mem_flags::mem_threadgroup);
2729
2730        // ── Phase 1: S = Q @ K^T via MMA ──
2731        // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2732        // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2733        // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2734        {
2735            simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2736            uint dim_chunks = head_dim / 8;
2737            for (uint dc = 0; dc < dim_chunks; dc++) {
2738                simdgroup_matrix<half, 8, 8> A, B;
2739                // A = Q[0:8, dc*8 : dc*8+8]  (rows of Q, no transpose)
2740                simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2741                // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2742                // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2743                // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2744                // transpose=true, which places it in the register as K^T of that sub-block.
2745                simdgroup_load(B,
2746                    k_sh + (simd_id * 8) * head_dim + dc * 8,
2747                    head_dim, ulong2(0, 0), true);
2748                simdgroup_multiply_accumulate(C, A, B, C);
2749            }
2750            // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2751            simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2752        }
2753        threadgroup_barrier(mem_flags::mem_threadgroup);
2754
2755        // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2756        // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2757        uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2758        for (uint i = tid; i < s_elems; i += 128) {
2759            uint q = i / FLASH_MMA_K_BLOCK;
2760            uint k = i % FLASH_MMA_K_BLOCK;
2761            uint global_q = q_block_start + q;
2762            uint global_kv = kv_base + k;
2763            bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2764            s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2765        }
2766        threadgroup_barrier(mem_flags::mem_threadgroup);
2767
2768        // ── Phase 2b: per-row max via simdgroup reduction ──
2769        // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2770        // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2771        {
2772            uint row_base = simd_id * 2;
2773            for (uint qr = 0; qr < 2; qr++) {
2774                uint q = row_base + qr;
2775                float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2776                float row_max = simd_max(my);
2777                if (simd_lane == 0) {
2778                    scratch[q] = row_max;  // tile_max[q]
2779                }
2780            }
2781        }
2782        threadgroup_barrier(mem_flags::mem_threadgroup);
2783
2784        // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2785        if (tid < FLASH_MMA_Q_BLOCK) {
2786            uint q = tid;
2787            float m_old = m_sh[q];
2788            float tile_max = scratch[q];
2789            float m_new = max(m_old, tile_max);
2790            float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2791            m_sh[q] = m_new;
2792            l_sh[q] = l_sh[q] * alpha;
2793            // scratch[q]               = m_new   (for phase 2d)
2794            // scratch[Q_BLOCK + q]     = alpha   (for phase 3)
2795            scratch[q] = m_new;
2796            scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2797        }
2798        threadgroup_barrier(mem_flags::mem_threadgroup);
2799
2800        // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2801        {
2802            uint row_base = simd_id * 2;
2803            for (uint qr = 0; qr < 2; qr++) {
2804                uint q = row_base + qr;
2805                float m_new = scratch[q];
2806                float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2807                p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2808                float row_sum = simd_sum(p);
2809                if (simd_lane == 0) {
2810                    scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2811                }
2812            }
2813        }
2814        threadgroup_barrier(mem_flags::mem_threadgroup);
2815
2816        // ── Phase 2e: l_sh += tile_sum ──
2817        if (tid < FLASH_MMA_Q_BLOCK) {
2818            uint q = tid;
2819            l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2820        }
2821
2822        // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2823        threadgroup_barrier(mem_flags::mem_threadgroup);
2824        for (uint i = tid; i < qblock_elems; i += 128) {
2825            uint q = i / head_dim;
2826            float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2827            o_sh[i] *= alpha;
2828        }
2829        threadgroup_barrier(mem_flags::mem_threadgroup);
2830
2831        // ── Phase 4: O += P @ V via MMA ──
2832        // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2833        // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2834        // For head_dim=64: 16 dims/sg = 2 tiles.  head_dim=128: 32 dims/sg = 4 tiles.
2835        {
2836            uint dims_per_sg = head_dim / 4;       // 16 or 32
2837            uint tiles_per_sg = dims_per_sg / 8;   // 2 or 4
2838            uint sg_d_base = simd_id * dims_per_sg;
2839            for (uint t = 0; t < tiles_per_sg; t++) {
2840                uint d_base = sg_d_base + t * 8;
2841                simdgroup_matrix<float, 8, 8> O_acc;
2842                simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2843                uint k_chunks = FLASH_MMA_K_BLOCK / 8;  // 4
2844                for (uint kc = 0; kc < k_chunks; kc++) {
2845                    simdgroup_matrix<half, 8, 8> A, B;
2846                    simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2847                                   ulong2(0, 0), false);
2848                    simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2849                                   ulong2(0, 0), false);
2850                    simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2851                }
2852                simdgroup_store(O_acc, o_sh + d_base, head_dim);
2853            }
2854        }
2855        threadgroup_barrier(mem_flags::mem_threadgroup);
2856    }
2857
2858    // ── Finalize: O /= l, write to output ──
2859    for (uint i = tid; i < qblock_elems; i += 128) {
2860        uint q = i / head_dim;
2861        uint d = i % head_dim;
2862        if (q < q_valid) {
2863            float l = l_sh[q];
2864            float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2865            uint token_idx = q_block_start + q;
2866            uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2867            output_batch[out_off] = o_sh[i] * inv_l;
2868        }
2869    }
2870}
2871
2872// ── rope_qk_batch ─────────────────────────────────────────────────────
2873// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2874// launch + memory barrier per layer. Both Q and K live in the same
2875// qkv_data buffer at different offsets within each token's row.
2876// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2877kernel void rope_qk_batch(
2878    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
2879    constant uint& M                 [[buffer(1)]],   // num tokens
2880    constant uint& base_pos          [[buffer(2)]],   // starting position
2881    constant uint& num_q_heads       [[buffer(3)]],
2882    constant uint& num_kv_heads      [[buffer(4)]],
2883    constant uint& head_dim          [[buffer(5)]],
2884    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
2885    constant float& theta            [[buffer(7)]],
2886    uint id [[thread_position_in_grid]])
2887{
2888    uint half_dim = head_dim / 2;
2889    uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2890    uint token = id / total_pairs;
2891    uint pair = id % total_pairs;
2892    if (token >= M) return;
2893
2894    uint pos = base_pos + token;
2895    uint q_pairs = num_q_heads * half_dim;
2896
2897    uint h, i, offset;
2898    if (pair < q_pairs) {
2899        // Q head
2900        h = pair / half_dim;
2901        i = pair % half_dim;
2902        offset = token * qkv_stride + h * head_dim + i * 2;
2903    } else {
2904        // K head
2905        uint kp = pair - q_pairs;
2906        h = kp / half_dim;
2907        i = kp % half_dim;
2908        uint k_start = num_q_heads * head_dim;
2909        offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2910    }
2911
2912    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2913    float angle = float(pos) * freq;
2914    float cos_val = cos(angle);
2915    float sin_val = sin(angle);
2916
2917    float x0 = qkv_data[offset];
2918    float x1 = qkv_data[offset + 1];
2919    qkv_data[offset]     = x0 * cos_val - x1 * sin_val;
2920    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2921}
2922
2923// ── copy_kv_both_batch ────────────────────────────────────────────────
2924// Fused K+V cache copy in a single dispatch: copies both K and V from
2925// the strided batch QKV buffer to their respective KV cache buffers.
2926// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2927kernel void copy_kv_both_batch(
2928    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride] f32
2929    device half* k_dst         [[buffer(1)]],  // K cache [max_seq, kv_dim] f16
2930    device half* v_dst         [[buffer(2)]],  // V cache [max_seq, kv_dim] f16
2931    constant uint& M           [[buffer(3)]],  // num tokens in batch
2932    constant uint& kv_dim      [[buffer(4)]],  // elements per KV vector
2933    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
2934    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
2935    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
2936    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
2937    uint id [[thread_position_in_grid]])
2938{
2939    // Total elements = M * kv_dim * 2 (K + V)
2940    uint total_kv = M * kv_dim;
2941    if (id >= total_kv * 2) return;
2942
2943    uint is_v = id / total_kv;        // 0 = K, 1 = V
2944    uint local_id = id % total_kv;
2945    uint token = local_id / kv_dim;
2946    uint d = local_id % kv_dim;
2947
2948    uint dst_off = (base_pos + token) * kv_dim + d;
2949    uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2950
2951    if (is_v) {
2952        v_dst[dst_off] = half(src[src_off]);
2953    } else {
2954        k_dst[dst_off] = half(src[src_off]);
2955    }
2956}
2957"#
2958    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2959    .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2960}
2961
2962// ---------------------------------------------------------------------------
2963// model.rs generation
2964// ---------------------------------------------------------------------------
2965
2966fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2967    let mut code = String::with_capacity(48 * 1024);
2968    emit_model_header(&mut code, config)?;
2969    emit_metal_model_struct(&mut code, config)?;
2970    emit_layer_buffers_struct(&mut code, config)?;
2971    emit_metal_model_impl(&mut code, config)?;
2972    emit_helper_functions(&mut code)?;
2973    Ok(code)
2974}
2975
2976fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2977    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2978    writeln!(
2979        code,
2980        "//! Model: {} ({} layers, hidden={})",
2981        config.architecture, config.num_layers, config.hidden_size
2982    )?;
2983    writeln!(code, "//!")?;
2984    writeln!(
2985        code,
2986        "//! Uses native Metal compute pipelines via the metal crate."
2987    )?;
2988    writeln!(code)?;
2989    writeln!(code, "#![allow(dead_code)]")?;
2990    writeln!(code)?;
2991    writeln!(code, "use metal::*;")?;
2992    writeln!(code, "#[allow(unused_imports)]")?;
2993    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2994    writeln!(code, "use std::mem;")?;
2995    writeln!(code)?;
2996
2997    // Model constants
2998    writeln!(
2999        code,
3000        "// ── Model constants ──────────────────────────────────"
3001    )?;
3002    writeln!(
3003        code,
3004        "pub const HIDDEN_SIZE: usize = {};",
3005        config.hidden_size
3006    )?;
3007    writeln!(
3008        code,
3009        "pub const INTERMEDIATE_SIZE: usize = {};",
3010        config.intermediate_size
3011    )?;
3012    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
3013    writeln!(
3014        code,
3015        "pub const NUM_HEADS: usize = {};",
3016        config.num_attention_heads
3017    )?;
3018    writeln!(
3019        code,
3020        "pub const NUM_KV_HEADS: usize = {};",
3021        config.num_kv_heads
3022    )?;
3023    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3024    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3025    let effective_seq_len = config.max_seq_len.min(4096);
3026    writeln!(
3027        code,
3028        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
3029        effective_seq_len, config.max_seq_len
3030    )?;
3031    writeln!(
3032        code,
3033        "pub const RMS_NORM_EPS: f32 = {:e};",
3034        config.rms_norm_eps
3035    )?;
3036    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3037    writeln!(
3038        code,
3039        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3040    )?;
3041    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3042    writeln!(code)?;
3043
3044    Ok(())
3045}
3046
3047fn emit_metal_model_struct(
3048    code: &mut String,
3049    config: &ModelConfig,
3050) -> Result<(), MetalCodegenError> {
3051    writeln!(
3052        code,
3053        "// ── MetalModel ──────────────────────────────────────────"
3054    )?;
3055    writeln!(code)?;
3056    writeln!(
3057        code,
3058        "/// Metal-accelerated transformer model for Apple Silicon."
3059    )?;
3060    writeln!(code, "///")?;
3061    writeln!(
3062        code,
3063        "/// Uses unified memory for zero-copy weight access and native Metal"
3064    )?;
3065    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3066    writeln!(code, "pub struct MetalModel {{")?;
3067    writeln!(code, "    device: Device,")?;
3068    writeln!(code, "    queue: CommandQueue,")?;
3069    writeln!(code)?;
3070    writeln!(code, "    // ── Compute pipelines ──")?;
3071    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
3072    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
3073    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
3074    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
3075    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
3076    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
3077    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
3078    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
3079    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
3080    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
3081    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
3082    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
3083    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
3084    writeln!(
3085        code,
3086        "    copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3087    )?;
3088    writeln!(code)?;
3089    writeln!(code, "    // ── Batched prefill pipelines ──")?;
3090    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
3091    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
3092    writeln!(
3093        code,
3094        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3095    )?;
3096    writeln!(code, "    matmul_q8_mma_pipeline: ComputePipelineState,")?;
3097    writeln!(code, "    matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3098    writeln!(
3099        code,
3100        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3101    )?;
3102    writeln!(
3103        code,
3104        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3105    )?;
3106    writeln!(
3107        code,
3108        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3109    )?;
3110    if config.qkv_bias {
3111        writeln!(code, "    add_bias_batch_pipeline: ComputePipelineState,")?;
3112    }
3113    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
3114    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
3115    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
3116    writeln!(
3117        code,
3118        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
3119    )?;
3120    writeln!(
3121        code,
3122        "    add_inplace_batch_pipeline: ComputePipelineState,"
3123    )?;
3124    writeln!(
3125        code,
3126        "    copy_embedding_batch_pipeline: ComputePipelineState,"
3127    )?;
3128    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
3129    writeln!(
3130        code,
3131        "    attention_flash_batch_pipeline: ComputePipelineState,"
3132    )?;
3133    writeln!(
3134        code,
3135        "    attention_mma_flash_batch_pipeline: ComputePipelineState,"
3136    )?;
3137    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
3138    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
3139    writeln!(
3140        code,
3141        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
3142    )?;
3143    writeln!(code)?;
3144    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
3145    writeln!(
3146        code,
3147        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3148    )?;
3149    writeln!(code, "    embed_buf: Buffer,")?;
3150    writeln!(code)?;
3151    writeln!(code, "    /// Per-layer weight buffers")?;
3152    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
3153    writeln!(code)?;
3154    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
3155    writeln!(code, "    norm_buf: Buffer,")?;
3156    writeln!(code)?;
3157    writeln!(
3158        code,
3159        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3160    )?;
3161    writeln!(code, "    lm_head_buf: Buffer,")?;
3162    writeln!(code)?;
3163    writeln!(
3164        code,
3165        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
3166    )?;
3167    writeln!(code, "    hidden_buf: Buffer,")?;
3168    writeln!(code, "    residual_buf: Buffer,")?;
3169    writeln!(code, "    normed_buf: Buffer,")?;
3170    writeln!(
3171        code,
3172        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3173    )?;
3174    writeln!(code, "    qkv_buf: Buffer,")?;
3175    writeln!(code, "    attn_out_buf: Buffer,")?;
3176    writeln!(code, "    attn_proj_buf: Buffer,")?;
3177    writeln!(
3178        code,
3179        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3180    )?;
3181    writeln!(code, "    gate_up_buf: Buffer,")?;
3182    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
3183    writeln!(code, "    ffn_out_buf: Buffer,")?;
3184    writeln!(code, "    add_tmp_buf: Buffer,")?;
3185    writeln!(code, "    logits_buf: Buffer,")?;
3186    writeln!(code)?;
3187    writeln!(code, "    // ── Batched prefill working buffers ──")?;
3188    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3189    writeln!(code, "    batch_hidden_buf: Buffer,")?;
3190    writeln!(
3191        code,
3192        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3193    )?;
3194    writeln!(code, "    batch_residual_buf: Buffer,")?;
3195    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3196    writeln!(code, "    batch_qkv_buf: Buffer,")?;
3197    writeln!(
3198        code,
3199        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3200    )?;
3201    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
3202    writeln!(
3203        code,
3204        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3205    )?;
3206    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
3207    writeln!(
3208        code,
3209        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3210    )?;
3211    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
3212    writeln!(
3213        code,
3214        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3215    )?;
3216    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
3217    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3218    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
3219    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
3220    writeln!(code, "    batch_tokens_buf: Buffer,")?;
3221    writeln!(code, "    /// Positions buffer for batch RoPE")?;
3222    writeln!(code, "    batch_positions_buf: Buffer,")?;
3223    writeln!(code)?;
3224    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
3225    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
3226    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
3227    writeln!(code)?;
3228    writeln!(code, "    // ── Inference state ──")?;
3229    writeln!(code, "    pos: usize,")?;
3230    writeln!(code)?;
3231    writeln!(
3232        code,
3233        "    /// Previous command buffer for double-buffered prefill."
3234    )?;
3235    writeln!(
3236        code,
3237        "    /// While the GPU executes token N, the CPU can encode token N+1."
3238    )?;
3239    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
3240    writeln!(code, "}}")?;
3241    writeln!(code)?;
3242
3243    Ok(())
3244}
3245
3246fn emit_layer_buffers_struct(
3247    code: &mut String,
3248    config: &ModelConfig,
3249) -> Result<(), MetalCodegenError> {
3250    writeln!(
3251        code,
3252        "/// Per-layer weight buffers for attention and FFN projections."
3253    )?;
3254    writeln!(code, "struct LayerBuffers {{")?;
3255    writeln!(code, "    attn_norm: Buffer,")?;
3256    writeln!(
3257        code,
3258        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3259    )?;
3260    writeln!(code, "    qkv_weight: Buffer,")?;
3261    if config.qkv_bias {
3262        writeln!(
3263            code,
3264            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3265        )?;
3266        writeln!(code, "    qkv_bias: Buffer,")?;
3267    }
3268    writeln!(code, "    o_weight: Buffer,")?;
3269    writeln!(code, "    ffn_norm: Buffer,")?;
3270    writeln!(
3271        code,
3272        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3273    )?;
3274    writeln!(code, "    gate_up_weight: Buffer,")?;
3275    writeln!(code, "    down_weight: Buffer,")?;
3276    writeln!(code, "}}")?;
3277    writeln!(code)?;
3278
3279    Ok(())
3280}
3281
3282fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3283    let hidden = config.hidden_size;
3284    let intermediate = config.intermediate_size;
3285    let _num_layers = config.num_layers;
3286    let _num_heads = config.num_attention_heads;
3287    let num_kv_heads = config.num_kv_heads;
3288    let head_dim = config.head_dim;
3289    let vocab = config.vocab_size;
3290    let effective_seq_len = config.max_seq_len.min(4096);
3291    let is_q8 = config.dtype == DType::Q8_0;
3292    let is_q4 = config.dtype == DType::Q4_0;
3293    let kv_dim = num_kv_heads * head_dim;
3294
3295    writeln!(code, "impl MetalModel {{")?;
3296
3297    // ── new() ──
3298    writeln!(
3299        code,
3300        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3301    )?;
3302    writeln!(code, "    ///")?;
3303    writeln!(
3304        code,
3305        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3306    )?;
3307    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3308    writeln!(
3309        code,
3310        "        let device = Device::system_default().expect(\"no Metal device found\");"
3311    )?;
3312    writeln!(code, "        let queue = device.new_command_queue();")?;
3313    writeln!(code)?;
3314
3315    // Compile shaders
3316    writeln!(
3317        code,
3318        "        // Compile Metal shaders from embedded source"
3319    )?;
3320    writeln!(
3321        code,
3322        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3323    )?;
3324    writeln!(
3325        code,
3326        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3327    )?;
3328    writeln!(
3329        code,
3330        "            .expect(\"failed to compile Metal shaders\");"
3331    )?;
3332    writeln!(code)?;
3333
3334    // Create compute pipelines
3335    writeln!(code, "        // Create compute pipelines")?;
3336    for (var, fn_name) in [
3337        ("matmul_pipeline", "matmul_vec"),
3338        ("matmul_q8_pipeline", "matmul_vec_q8"),
3339        ("matmul_q4_pipeline", "matmul_vec_q4"),
3340        ("rms_norm_pipeline", "rms_norm"),
3341        ("rope_pipeline", "rope"),
3342        ("softmax_pipeline", "softmax"),
3343        ("silu_mul_pipeline", "silu_mul"),
3344        ("silu_mul_fused_pipeline", "silu_mul_fused"),
3345        ("add_pipeline", "elementwise_add"),
3346        ("attention_pipeline", "attention"),
3347        ("add_inplace_pipeline", "add_inplace"),
3348        ("copy_pipeline", "copy_buffer"),
3349        ("copy_offset_pipeline", "copy_offset"),
3350        ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3351        ("matmul_batch_pipeline", "matmul_vec_batch"),
3352        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3353        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3354        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3355        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3356        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3357        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3358        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3359        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3360        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3361        ("rope_batch_pipeline", "rope_batch"),
3362        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3363        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3364        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3365        ("attention_batch_pipeline", "attention_batch"),
3366        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3367        (
3368            "attention_mma_flash_batch_pipeline",
3369            "attention_mma_flash_batch",
3370        ),
3371        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3372        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3373        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3374    ] {
3375        writeln!(
3376            code,
3377            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3378        )?;
3379    }
3380    if config.qkv_bias {
3381        writeln!(
3382            code,
3383            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3384        )?;
3385    }
3386    writeln!(code)?;
3387
3388    // Weight loading
3389    writeln!(
3390        code,
3391        "        // Load weights into Metal shared-memory buffers"
3392    )?;
3393    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3394    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3395    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3396    writeln!(code)?;
3397    writeln!(
3398        code,
3399        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3400    )?;
3401    writeln!(code)?;
3402    writeln!(
3403        code,
3404        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3405    )?;
3406    writeln!(
3407        code,
3408        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3409    )?;
3410    writeln!(code, "            let byte_len = n * f32_size;")?;
3411    writeln!(code, "            let cur = cursor.get();")?;
3412    writeln!(
3413        code,
3414        "            let data = &weights[cur..cur + byte_len];"
3415    )?;
3416    writeln!(code, "            cursor.set(cur + byte_len);")?;
3417    writeln!(code, "            device.new_buffer_with_data(")?;
3418    writeln!(code, "                data.as_ptr() as *const _,")?;
3419    writeln!(code, "                byte_len as u64,")?;
3420    writeln!(
3421        code,
3422        "                MTLResourceOptions::StorageModeShared,"
3423    )?;
3424    writeln!(code, "            )")?;
3425    writeln!(code, "        }};")?;
3426    writeln!(code)?;
3427
3428    if is_q8 {
3429        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3430        // We load them directly into Metal buffers without dequantizing,
3431        // and use the matmul_vec_q8 shader that operates on quantized data.
3432        // This cuts GPU memory usage and memory bandwidth to roughly a quarter of f32 dequantization (34 vs 128 bytes per 32 values).
3433        writeln!(
3434            code,
3435            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3436        )?;
3437        writeln!(
3438            code,
3439            "        // as raw bytes into a Metal buffer (no dequantization)."
3440        )?;
3441        writeln!(
3442            code,
3443            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3444        )?;
3445        writeln!(
3446            code,
3447            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3448        )?;
3449        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3450        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3451        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3452        writeln!(code, "            let cur = cursor.get();")?;
3453        writeln!(
3454            code,
3455            "            let data = &weights[cur..cur + total_raw];"
3456        )?;
3457        writeln!(code, "            cursor.set(cur + total_raw);")?;
3458        writeln!(code, "            device.new_buffer_with_data(")?;
3459        writeln!(code, "                data.as_ptr() as *const _,")?;
3460        writeln!(code, "                total_raw as u64,")?;
3461        writeln!(
3462            code,
3463            "                MTLResourceOptions::StorageModeShared,"
3464        )?;
3465        writeln!(code, "            )")?;
3466        writeln!(code, "        }};")?;
3467        writeln!(code)?;
3468        writeln!(
3469            code,
3470            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3471        )?;
3472        writeln!(
3473            code,
3474            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3475        )?;
3476        writeln!(
3477            code,
3478            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3479        )?;
3480        writeln!(
3481            code,
3482            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3483        )?;
3484        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3485        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3486        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3487        writeln!(code, "            let cur = cursor.get();")?;
3488        writeln!(
3489            code,
3490            "            let data = &weights[cur..cur + total_raw];"
3491        )?;
3492        writeln!(code, "            cursor.set(cur + total_raw);")?;
3493        writeln!(code, "            device.new_buffer_with_data(")?;
3494        writeln!(code, "                data.as_ptr() as *const _,")?;
3495        writeln!(code, "                total_raw as u64,")?;
3496        writeln!(
3497            code,
3498            "                MTLResourceOptions::StorageModeShared,"
3499        )?;
3500        writeln!(code, "            )")?;
3501        writeln!(code, "        }};")?;
3502        writeln!(code)?;
3503    }
3504
3505    if is_q4 {
3506        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3507        // We load them directly into Metal buffers without dequantizing,
3508        // and use the matmul_vec_q4 shader that operates on quantized data.
3509        // This cuts GPU memory usage to roughly one seventh of f32 dequantization (18 vs 128 bytes per 32 values).
3510        writeln!(
3511            code,
3512            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3513        )?;
3514        writeln!(
3515            code,
3516            "        // as raw bytes into a Metal buffer (no dequantization)."
3517        )?;
3518        writeln!(
3519            code,
3520            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3521        )?;
3522        writeln!(
3523            code,
3524            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3525        )?;
3526        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3527        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3528        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3529        writeln!(code, "            let cur = cursor.get();")?;
3530        writeln!(
3531            code,
3532            "            let data = &weights[cur..cur + total_raw];"
3533        )?;
3534        writeln!(code, "            cursor.set(cur + total_raw);")?;
3535        writeln!(code, "            device.new_buffer_with_data(")?;
3536        writeln!(code, "                data.as_ptr() as *const _,")?;
3537        writeln!(code, "                total_raw as u64,")?;
3538        writeln!(
3539            code,
3540            "                MTLResourceOptions::StorageModeShared,"
3541        )?;
3542        writeln!(code, "            )")?;
3543        writeln!(code, "        }};")?;
3544        writeln!(code)?;
3545        writeln!(
3546            code,
3547            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3548        )?;
3549        writeln!(
3550            code,
3551            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3552        )?;
3553        writeln!(
3554            code,
3555            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3556        )?;
3557        writeln!(
3558            code,
3559            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3560        )?;
3561        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3562        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3563        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3564        writeln!(code, "            let cur = cursor.get();")?;
3565        writeln!(
3566            code,
3567            "            let data = &weights[cur..cur + total_raw];"
3568        )?;
3569        writeln!(code, "            cursor.set(cur + total_raw);")?;
3570        writeln!(code, "            device.new_buffer_with_data(")?;
3571        writeln!(code, "                data.as_ptr() as *const _,")?;
3572        writeln!(code, "                total_raw as u64,")?;
3573        writeln!(
3574            code,
3575            "                MTLResourceOptions::StorageModeShared,"
3576        )?;
3577        writeln!(code, "            )")?;
3578        writeln!(code, "        }};")?;
3579        writeln!(code)?;
3580    }
3581
3582    writeln!(
3583        code,
3584        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3585    )?;
3586    writeln!(code)?;
3587
3588    // Per-layer weights
3589    writeln!(
3590        code,
3591        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3592    )?;
3593    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3594
3595    // attn_norm is always f32
3596    writeln!(
3597        code,
3598        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3599    )?;
3600
3601    let qkv_rows = hidden + 2 * kv_dim;
3602    if is_q8 {
3603        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3604        writeln!(
3605            code,
3606            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3607        )?;
3608        if config.qkv_bias {
3609            writeln!(
3610                code,
3611                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3612            )?;
3613            writeln!(
3614                code,
3615                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3616            )?;
3617        }
3618        writeln!(
3619            code,
3620            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3621        )?;
3622    } else if is_q4 {
3623        writeln!(
3624            code,
3625            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3626        )?;
3627        if config.qkv_bias {
3628            writeln!(
3629                code,
3630                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3631            )?;
3632        }
3633        writeln!(
3634            code,
3635            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3636        )?;
3637    } else {
3638        writeln!(
3639            code,
3640            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3641        )?;
3642        if config.qkv_bias {
3643            writeln!(
3644                code,
3645                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3646            )?;
3647        }
3648        writeln!(
3649            code,
3650            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3651        )?;
3652    }
3653
3654    // ffn_norm is always f32
3655    writeln!(
3656        code,
3657        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3658    )?;
3659
3660    let gate_up_rows = 2 * intermediate;
3661    if is_q8 {
3662        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3663        writeln!(
3664            code,
3665            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3666        )?;
3667        writeln!(
3668            code,
3669            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3670        )?;
3671    } else if is_q4 {
3672        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3673        writeln!(
3674            code,
3675            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3676        )?;
3677        writeln!(
3678            code,
3679            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3680        )?;
3681    } else {
3682        // Fused gate+up weight: read both as a single contiguous f32 buffer
3683        writeln!(
3684            code,
3685            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3686        )?;
3687        writeln!(
3688            code,
3689            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3690        )?;
3691    }
3692
3693    writeln!(code, "            layers.push(LayerBuffers {{")?;
3694    writeln!(code, "                attn_norm,")?;
3695    writeln!(code, "                qkv_weight,")?;
3696    if config.qkv_bias {
3697        writeln!(code, "                qkv_bias,")?;
3698    }
3699    writeln!(code, "                o_weight,")?;
3700    writeln!(code, "                ffn_norm,")?;
3701    writeln!(code, "                gate_up_weight,")?;
3702    writeln!(code, "                down_weight,")?;
3703    writeln!(code, "            }});")?;
3704    writeln!(code, "        }}")?;
3705    writeln!(code)?;
3706
3707    // final_norm is always f32
3708    writeln!(
3709        code,
3710        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3711    )?;
3712    writeln!(code)?;
3713
3714    // lm_head
3715    if is_q8 {
3716        writeln!(
3717            code,
3718            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3719        )?;
3720    } else if is_q4 {
3721        writeln!(
3722            code,
3723            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3724        )?;
3725    } else {
3726        writeln!(
3727            code,
3728            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3729        )?;
3730    }
3731    writeln!(code)?;
3732
3733    // Working buffers
3734    let hidden_bytes = hidden * 4;
3735    let _kv_dim_bytes = kv_dim * 4;
3736    let intermediate_bytes = intermediate * 4;
3737    let vocab_bytes = vocab * 4;
3738    // KV cache is stored as f16 (2 bytes/element) to halve attention memory
3739    // bandwidth in long-context decode.
3740    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3741
3742    writeln!(
3743        code,
3744        "        // Allocate working buffers (shared memory for zero-copy)"
3745    )?;
3746    writeln!(
3747        code,
3748        "        let opts = MTLResourceOptions::StorageModeShared;"
3749    )?;
3750    writeln!(
3751        code,
3752        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3753    )?;
3754    writeln!(
3755        code,
3756        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3757    )?;
3758    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3759    writeln!(
3760        code,
3761        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3762    )?;
3763    writeln!(
3764        code,
3765        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3766    )?;
3767    writeln!(
3768        code,
3769        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3770    )?;
3771    writeln!(
3772        code,
3773        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3774    )?;
3775    writeln!(
3776        code,
3777        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3778    )?;
3779    let gate_up_buf_bytes = 2 * intermediate * 4;
3780    writeln!(
3781        code,
3782        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3783    )?;
3784    writeln!(
3785        code,
3786        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3787    )?;
3788    writeln!(
3789        code,
3790        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3791    )?;
3792    writeln!(
3793        code,
3794        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3795    )?;
3796    writeln!(
3797        code,
3798        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3799    )?;
3800    writeln!(
3801        code,
3802        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3803    )?;
3804    writeln!(code)?;
3805
3806    // Batch prefill working buffers
3807    let batch_hidden_bytes = hidden * 4; // per-token
3808    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3809    let batch_gate_up_bytes = 2 * intermediate * 4;
3810    let batch_intermediate_bytes = intermediate * 4;
3811    writeln!(
3812        code,
3813        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3814    )?;
3815    writeln!(
3816        code,
3817        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3818    )?;
3819    writeln!(
3820        code,
3821        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3822    )?;
3823    writeln!(
3824        code,
3825        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3826    )?;
3827    writeln!(
3828        code,
3829        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3830    )?;
3831    writeln!(
3832        code,
3833        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3834    )?;
3835    writeln!(
3836        code,
3837        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3838    )?;
3839    writeln!(
3840        code,
3841        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3842    )?;
3843    writeln!(
3844        code,
3845        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3846    )?;
3847    writeln!(
3848        code,
3849        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3850    )?;
3851    writeln!(
3852        code,
3853        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3854    )?;
3855    writeln!(code)?;
3856
3857    // KV cache buffers
3858    writeln!(code, "        // KV cache buffers (per-layer)")?;
3859    writeln!(
3860        code,
3861        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3862    )?;
3863    writeln!(
3864        code,
3865        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3866    )?;
3867    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3868    writeln!(
3869        code,
3870        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3871    )?;
3872    writeln!(
3873        code,
3874        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3875    )?;
3876    writeln!(code, "        }}")?;
3877    writeln!(code)?;
3878
3879    writeln!(code, "        Self {{")?;
3880    writeln!(code, "            device,")?;
3881    writeln!(code, "            queue,")?;
3882    writeln!(code, "            matmul_pipeline,")?;
3883    writeln!(code, "            matmul_q8_pipeline,")?;
3884    writeln!(code, "            matmul_q4_pipeline,")?;
3885    writeln!(code, "            rms_norm_pipeline,")?;
3886    writeln!(code, "            rope_pipeline,")?;
3887    writeln!(code, "            softmax_pipeline,")?;
3888    writeln!(code, "            silu_mul_pipeline,")?;
3889    writeln!(code, "            silu_mul_fused_pipeline,")?;
3890    writeln!(code, "            add_pipeline,")?;
3891    writeln!(code, "            attention_pipeline,")?;
3892    writeln!(code, "            add_inplace_pipeline,")?;
3893    writeln!(code, "            copy_pipeline,")?;
3894    writeln!(code, "            copy_offset_pipeline,")?;
3895    writeln!(code, "            copy_f32_to_f16_offset_pipeline,")?;
3896    writeln!(code, "            matmul_batch_pipeline,")?;
3897    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3898    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3899    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3900    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3901    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3902    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3903    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3904    if config.qkv_bias {
3905        writeln!(code, "            add_bias_batch_pipeline,")?;
3906    }
3907    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3908    writeln!(code, "            rms_norm_batch_pipeline,")?;
3909    writeln!(code, "            rope_batch_pipeline,")?;
3910    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3911    writeln!(code, "            add_inplace_batch_pipeline,")?;
3912    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3913    writeln!(code, "            attention_batch_pipeline,")?;
3914    writeln!(code, "            attention_flash_batch_pipeline,")?;
3915    writeln!(code, "            attention_mma_flash_batch_pipeline,")?;
3916    writeln!(code, "            copy_kv_batch_pipeline,")?;
3917    writeln!(code, "            rope_qk_batch_pipeline,")?;
3918    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3919    writeln!(code, "            embed_buf,")?;
3920    writeln!(code, "            layers,")?;
3921    writeln!(code, "            norm_buf,")?;
3922    writeln!(code, "            lm_head_buf,")?;
3923    writeln!(code, "            hidden_buf,")?;
3924    writeln!(code, "            residual_buf,")?;
3925    writeln!(code, "            normed_buf,")?;
3926    writeln!(code, "            qkv_buf,")?;
3927    writeln!(code, "            attn_out_buf,")?;
3928    writeln!(code, "            attn_proj_buf,")?;
3929    writeln!(code, "            gate_up_buf,")?;
3930    writeln!(code, "            ffn_hidden_buf,")?;
3931    writeln!(code, "            ffn_out_buf,")?;
3932    writeln!(code, "            add_tmp_buf,")?;
3933    writeln!(code, "            logits_buf,")?;
3934    writeln!(code, "            batch_hidden_buf,")?;
3935    writeln!(code, "            batch_residual_buf,")?;
3936    writeln!(code, "            batch_qkv_buf,")?;
3937    writeln!(code, "            batch_attn_out_buf,")?;
3938    writeln!(code, "            batch_attn_proj_buf,")?;
3939    writeln!(code, "            batch_gate_up_buf,")?;
3940    writeln!(code, "            batch_ffn_hidden_buf,")?;
3941    writeln!(code, "            batch_ffn_out_buf,")?;
3942    writeln!(code, "            batch_tokens_buf,")?;
3943    writeln!(code, "            batch_positions_buf,")?;
3944    writeln!(code, "            k_cache,")?;
3945    writeln!(code, "            v_cache,")?;
3946    writeln!(code, "            pos: 0,")?;
3947    writeln!(code, "            prev_cmd: None,")?;
3948    writeln!(code, "        }}")?;
3949    writeln!(code, "    }}")?;
3950    writeln!(code)?;
3951
3952    // ── forward() ──
3953    writeln!(
3954        code,
3955        "    /// Run the forward pass for a single token at the current position."
3956    )?;
3957    writeln!(code, "    ///")?;
3958    writeln!(
3959        code,
3960        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3961    )?;
3962    writeln!(code, "    ///")?;
3963    writeln!(
3964        code,
3965        "    /// All GPU operations are encoded into a single command buffer and"
3966    )?;
3967    writeln!(
3968        code,
3969        "    /// committed once at the end, avoiding per-operation synchronization."
3970    )?;
3971    writeln!(
3972        code,
3973        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3974    )?;
3975    writeln!(
3976        code,
3977        "        // Wait for any pending prefill command buffer"
3978    )?;
3979    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3980    writeln!(code, "            prev.wait_until_completed();")?;
3981    writeln!(code, "        }}")?;
3982    writeln!(code)?;
3983    writeln!(code, "        let pos = self.pos;")?;
3984    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3985    writeln!(code)?;
3986
3987    // Single compute encoder for the entire forward pass — no blit encoder
3988    // transitions. Copy operations use compute copy kernels instead of blits.
3989    let matmul_fn = if is_q8 {
3990        "dispatch_matmul_q8"
3991    } else if is_q4 {
3992        "dispatch_matmul_q4"
3993    } else {
3994        "dispatch_matmul"
3995    };
3996
3997    writeln!(
3998        code,
3999        "        // Single compute encoder for the entire forward pass (no blit transitions)"
4000    )?;
4001    writeln!(code, "        {{")?;
4002    writeln!(
4003        code,
4004        "            let enc = cmd.new_compute_command_encoder();"
4005    )?;
4006    writeln!(code)?;
4007
4008    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
4009    writeln!(
4010        code,
4011        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
4012    )?;
4013    writeln!(
4014        code,
4015        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
4016    )?;
4017    writeln!(
4018        code,
4019        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
4020    )?;
4021    writeln!(
4022        code,
4023        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4024        hidden * 4,
4025    )?;
4026    writeln!(code, "            unsafe {{")?;
4027    writeln!(
4028        code,
4029        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
4030    )?;
4031    writeln!(
4032        code,
4033        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4034    )?;
4035    writeln!(
4036        code,
4037        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
4038    )?;
4039    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
4040    writeln!(
4041        code,
4042        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4043    )?;
4044    writeln!(code, "                    hidden_ptr,")?;
4045    writeln!(code, "                    HIDDEN_SIZE,")?;
4046    writeln!(code, "                );")?;
4047    writeln!(
4048        code,
4049        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4050    )?;
4051    writeln!(code, "            }}")?;
4052    writeln!(code)?;
4053
4054    // 2. Transformer layers
4055    writeln!(code, "            // 2. Transformer layers")?;
4056    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4057    writeln!(code)?;
4058    let q_byte_offset = 0usize;
4059    let k_float_offset = hidden;
4060    let v_float_offset = hidden + kv_dim;
4061
4062    writeln!(
4063        code,
4064        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
4065    )?;
4066    writeln!(
4067        code,
4068        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4069    )?;
4070    writeln!(
4071        code,
4072        "                // Fused Q+K+V matmul: single dispatch for all three projections"
4073    )?;
4074    writeln!(
4075        code,
4076        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4077    )?;
4078    if config.qkv_bias {
4079        writeln!(
4080            code,
4081            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4082        )?;
4083        writeln!(
4084            code,
4085            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4086        )?;
4087    }
4088    writeln!(
4089        code,
4090        "                // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
4091    )?;
4092    writeln!(
4093        code,
4094        "                self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4095    )?;
4096    writeln!(code)?;
4097    writeln!(
4098        code,
4099        "                // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
4100    )?;
4101    writeln!(
4102        code,
4103        "                self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4104    )?;
4105    writeln!(code)?;
4106    writeln!(
4107        code,
4108        "                // Attention using Q from qkv_buf (offset 0)"
4109    )?;
4110    writeln!(
4111        code,
4112        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4113    )?;
4114    writeln!(
4115        code,
4116        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4117    )?;
4118    writeln!(
4119        code,
4120        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4121    )?;
4122    writeln!(
4123        code,
4124        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4125    )?;
4126    writeln!(
4127        code,
4128        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4129    )?;
4130    writeln!(
4131        code,
4132        "                // Fused gate+up matmul: single dispatch for both projections"
4133    )?;
4134    writeln!(
4135        code,
4136        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4137    )?;
4138    writeln!(
4139        code,
4140        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4141    )?;
4142    writeln!(
4143        code,
4144        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4145    )?;
4146    writeln!(
4147        code,
4148        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4149    )?;
4150    writeln!(code, "            }}")?;
4151    writeln!(code)?;
4152
4153    // 3. Final RMS norm + logits
4154    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
4155    writeln!(
4156        code,
4157        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4158    )?;
4159    writeln!(
4160        code,
4161        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4162    )?;
4163    writeln!(code)?;
4164    writeln!(code, "            enc.end_encoding();")?;
4165    writeln!(code, "        }}")?;
4166    writeln!(code)?;
4167
4168    // 5. Single commit + wait, then read back logits
4169    writeln!(
4170        code,
4171        "        // 5. Commit all GPU work and wait for completion"
4172    )?;
4173    writeln!(code, "        cmd.commit();")?;
4174    writeln!(code, "        cmd.wait_until_completed();")?;
4175    writeln!(code)?;
4176    writeln!(code, "        // 6. Read back logits from GPU")?;
4177    writeln!(code, "        let logits = unsafe {{")?;
4178    writeln!(
4179        code,
4180        "            let ptr = self.logits_buf.contents() as *const f32;"
4181    )?;
4182    writeln!(
4183        code,
4184        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4185    )?;
4186    writeln!(code, "        }};")?;
4187    writeln!(code)?;
4188    writeln!(code, "        self.pos += 1;")?;
4189    writeln!(code, "        logits")?;
4190    writeln!(code, "    }}")?;
4191    writeln!(code)?;
4192
4193    // ── forward_profile: instrumented forward with per-operation timing ──
4194    writeln!(
4195        code,
4196        "    /// Profiling forward pass that prints per-stage GPU timing."
4197    )?;
4198    writeln!(code, "    ///")?;
4199    writeln!(
4200        code,
4201        "    /// Each stage is committed and waited on separately so that GPU timestamps"
4202    )?;
4203    writeln!(
4204        code,
4205        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
4206    )?;
4207    writeln!(
4208        code,
4209        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
4210    )?;
4211    writeln!(
4212        code,
4213        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4214    )?;
4215    writeln!(code, "        use std::time::Instant;")?;
4216    writeln!(code)?;
4217    writeln!(
4218        code,
4219        "        // Wait for any pending prefill command buffer"
4220    )?;
4221    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4222    writeln!(code, "            prev.wait_until_completed();")?;
4223    writeln!(code, "        }}")?;
4224    writeln!(code)?;
4225    writeln!(code, "        let pos = self.pos;")?;
4226    writeln!(code)?;
4227
4228    // Stage: embedding (CPU, no GPU)
4229    writeln!(
4230        code,
4231        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
4232    )?;
4233    writeln!(code, "        let t_embed = Instant::now();")?;
4234    writeln!(code, "        unsafe {{")?;
4235    writeln!(
4236        code,
4237        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
4238    )?;
4239    writeln!(
4240        code,
4241        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4242    )?;
4243    writeln!(
4244        code,
4245        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
4246    )?;
4247    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
4248    writeln!(
4249        code,
4250        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4251    )?;
4252    writeln!(code, "                hidden_ptr,")?;
4253    writeln!(code, "                HIDDEN_SIZE,")?;
4254    writeln!(code, "            );")?;
4255    writeln!(
4256        code,
4257        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4258    )?;
4259    writeln!(code, "        }}")?;
4260    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
4261    writeln!(code)?;
4262
4263    // Stage: Transformer layers (all together on GPU)
4264    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
4265    writeln!(code, "        let t_layers = Instant::now();")?;
4266    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4267    writeln!(code, "        {{")?;
4268    writeln!(
4269        code,
4270        "            let enc = cmd.new_compute_command_encoder();"
4271    )?;
4272    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4273    writeln!(
4274        code,
4275        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4276    )?;
4277    writeln!(
4278        code,
4279        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4280    )?;
4281    if config.qkv_bias {
4282        writeln!(
4283            code,
4284            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4285        )?;
4286    }
4287    writeln!(
4288        code,
4289        "                self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4290    )?;
4291    writeln!(
4292        code,
4293        "                self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4294    )?;
4295    writeln!(
4296        code,
4297        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4298    )?;
4299    writeln!(
4300        code,
4301        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4302    )?;
4303    writeln!(
4304        code,
4305        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4306    )?;
4307    writeln!(
4308        code,
4309        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4310    )?;
4311    writeln!(
4312        code,
4313        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4314    )?;
4315    writeln!(
4316        code,
4317        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4318    )?;
4319    writeln!(
4320        code,
4321        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4322    )?;
4323    writeln!(
4324        code,
4325        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4326    )?;
4327    writeln!(code, "            }}")?;
4328    writeln!(code, "            enc.end_encoding();")?;
4329    writeln!(code, "        }}")?;
4330    writeln!(code, "        cmd.commit();")?;
4331    writeln!(code, "        cmd.wait_until_completed();")?;
4332    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4333    writeln!(code)?;
4334
4335    // Stage: Final norm + logits
4336    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4337    writeln!(code, "        let t_logits = Instant::now();")?;
4338    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4339    writeln!(code, "        {{")?;
4340    writeln!(
4341        code,
4342        "            let enc = cmd2.new_compute_command_encoder();"
4343    )?;
4344    writeln!(
4345        code,
4346        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4347    )?;
4348    writeln!(
4349        code,
4350        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4351    )?;
4352    writeln!(code, "            enc.end_encoding();")?;
4353    writeln!(code, "        }}")?;
4354    writeln!(code, "        cmd2.commit();")?;
4355    writeln!(code, "        cmd2.wait_until_completed();")?;
4356    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4357    writeln!(code)?;
4358
4359    // Print profile results
4360    writeln!(
4361        code,
4362        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4363    )?;
4364    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4365    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4366    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4367    writeln!(
4368        code,
4369        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4370    )?;
4371    writeln!(code)?;
4372
4373    // Read back logits
4374    writeln!(code, "        let logits = unsafe {{")?;
4375    writeln!(
4376        code,
4377        "            let ptr = self.logits_buf.contents() as *const f32;"
4378    )?;
4379    writeln!(
4380        code,
4381        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4382    )?;
4383    writeln!(code, "        }};")?;
4384    writeln!(code)?;
4385    writeln!(code, "        self.pos += 1;")?;
4386    writeln!(code, "        logits")?;
4387    writeln!(code, "    }}")?;
4388    writeln!(code)?;
4389
4390    // ── forward_prefill: single-token async forward (backward compat) ──
4391    writeln!(
4392        code,
4393        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4394    )?;
4395    writeln!(code, "    ///")?;
4396    writeln!(
4397        code,
4398        "    /// Commits the command buffer without waiting, enabling double-buffered"
4399    )?;
4400    writeln!(
4401        code,
4402        "    /// execution: GPU processes token N while CPU encodes token N+1."
4403    )?;
4404    writeln!(
4405        code,
4406        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4407    )?;
4408    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4409    writeln!(code, "    }}")?;
4410    writeln!(code)?;
4411
4412    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4413    // Batched matmuls for QKV/O/FFN projections; causal attention is also batched (one dispatch per layer covering all M tokens).
4414    let batch_matmul_fn = if is_q8 {
4415        "dispatch_matmul_q8_batch"
4416    } else if is_q4 {
4417        "dispatch_matmul_q4_batch"
4418    } else {
4419        "dispatch_matmul_batch"
4420    };
4421
4422    writeln!(
4423        code,
4424        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4425    )?;
4426    writeln!(code, "    ///")?;
4427    writeln!(
4428        code,
4429        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4430    )?;
4431    writeln!(
4432        code,
4433        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4434    )?;
4435    writeln!(
4436        code,
4437        "    /// This provides significant speedup during prompt prefill."
4438    )?;
4439    writeln!(
4440        code,
4441        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4442    )?;
4443    writeln!(code, "        if tokens.is_empty() {{ return; }}")?;
4444    writeln!(
4445        code,
4446        "        // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4447    )?;
4448    writeln!(
4449        code,
4450        "        // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4451    )?;
4452    writeln!(
4453        code,
4454        "        // longer than that must be processed iteratively.  The KV cache"
4455    )?;
4456    writeln!(code, "        // carries state across chunks via self.pos.")?;
4457    writeln!(
4458        code,
4459        "        for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4460    )?;
4461    writeln!(code, "        let m = chunk.len();")?;
4462    writeln!(code, "        if m == 0 {{ continue; }}")?;
4463    writeln!(code, "        let start_pos = self.pos;")?;
4464    writeln!(code)?;
4465    writeln!(code, "        // Wait for any pending command buffer")?;
4466    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4467    writeln!(code, "            prev.wait_until_completed();")?;
4468    writeln!(code, "        }}")?;
4469    writeln!(code)?;
4470
4471    // Upload token IDs and positions to GPU
4472    writeln!(
4473        code,
4474        "        // Upload token IDs and positions to GPU buffers"
4475    )?;
4476    writeln!(code, "        unsafe {{")?;
4477    writeln!(
4478        code,
4479        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4480    )?;
4481    writeln!(
4482        code,
4483        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4484    )?;
4485    writeln!(code, "            for i in 0..m {{")?;
4486    writeln!(code, "                *tok_ptr.add(i) = chunk[i];")?;
4487    writeln!(
4488        code,
4489        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4490    )?;
4491    writeln!(code, "            }}")?;
4492    writeln!(code, "        }}")?;
4493    writeln!(code)?;
4494
4495    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4496    writeln!(code, "        {{")?;
4497    writeln!(
4498        code,
4499        "            let enc = cmd.new_compute_command_encoder();"
4500    )?;
4501    writeln!(code)?;
4502
4503    // 1. Batch embedding lookup
4504    writeln!(
4505        code,
4506        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4507    )?;
4508    writeln!(
4509        code,
4510        "            self.dispatch_copy_embedding_batch(&enc, m);"
4511    )?;
4512    // Copy batch_hidden -> batch_residual
4513    writeln!(
4514        code,
4515        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4516    )?;
4517    writeln!(code)?;
4518
4519    // 2. Transformer layers
4520    writeln!(code, "            // 2. Transformer layers")?;
4521    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4522    writeln!(code)?;
4523
4524    // Batch RMS norm: residual -> hidden (batched)
4525    writeln!(
4526        code,
4527        "                // Batch RMS norm: batch_residual -> batch_hidden"
4528    )?;
4529    writeln!(
4530        code,
4531        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4532    )?;
4533
4534    // Batch QKV matmul
4535    writeln!(
4536        code,
4537        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4538    )?;
4539    writeln!(
4540        code,
4541        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4542    )?;
4543    if config.qkv_bias {
4544        writeln!(
4545            code,
4546            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4547        )?;
4548        writeln!(
4549            code,
4550            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4551        )?;
4552    }
4553    writeln!(code)?;
4554
4555    // Fused RoPE on Q+K portions in a single dispatch
4556    let k_float_offset = hidden;
4557    writeln!(
4558        code,
4559        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4560    )?;
4561    writeln!(
4562        code,
4563        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4564    )?;
4565    writeln!(code)?;
4566
4567    // Fused KV cache update: copy both K and V in a single dispatch
4568    let v_float_offset = hidden + kv_dim;
4569    writeln!(
4570        code,
4571        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4572    )?;
4573    writeln!(
4574        code,
4575        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4576    )?;
4577    writeln!(code)?;
4578
4579    // Batched causal attention: ONE dispatch for all M tokens
4580    writeln!(
4581        code,
4582        "                // Batched causal attention: one dispatch for all M tokens"
4583    )?;
4584    writeln!(
4585        code,
4586        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4587    )?;
4588    writeln!(code)?;
4589
4590    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4591    writeln!(code, "                // Batched O projection")?;
4592    writeln!(
4593        code,
4594        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4595    )?;
4596    writeln!(code)?;
4597
4598    // Batch add: residual += attn_proj for all tokens
4599    writeln!(
4600        code,
4601        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4602    )?;
4603    writeln!(code)?;
4604
4605    // Batch FFN
4606    writeln!(
4607        code,
4608        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4609    )?;
4610    writeln!(
4611        code,
4612        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4613    )?;
4614    writeln!(
4615        code,
4616        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4617    )?;
4618    writeln!(
4619        code,
4620        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4621    )?;
4622    writeln!(
4623        code,
4624        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4625    )?;
4626    writeln!(
4627        code,
4628        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4629    )?;
4630    writeln!(code, "            }}")?;
4631    writeln!(code)?;
4632
4633    // Copy last token's residual to single-token residual_buf for next forward() call
4634    writeln!(
4635        code,
4636        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4637    )?;
4638    writeln!(
4639        code,
4640        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4641    )?;
4642    writeln!(code)?;
4643    writeln!(code, "            enc.end_encoding();")?;
4644    writeln!(code, "        }}")?;
4645    writeln!(code)?;
4646
4647    writeln!(code, "        cmd.commit();")?;
4648    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4649    writeln!(code, "        self.pos += m;")?;
4650    writeln!(code, "        }}  // end for chunk")?;
4651    writeln!(code, "    }}")?;
4652    writeln!(code)?;
4653
4654    // ── reset() — rewind KV cache position for new inference requests ──
4655    writeln!(
4656        code,
4657        "    /// Reset the model state for a new inference request."
4658    )?;
4659    writeln!(code, "    pub fn reset(&mut self) {{")?;
4660    writeln!(code, "        self.pos = 0;")?;
4661    writeln!(code, "        self.prev_cmd = None;")?;
4662    writeln!(code, "    }}")?;
4663    writeln!(code)?;
4664
4665    // ── Private dispatch helpers (all take a shared compute encoder) ──
4666    writeln!(
4667        code,
4668        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4669    )?;
4670    writeln!(
4671        code,
4672        "    // These methods set pipeline state + buffers + dispatch on an existing"
4673    )?;
4674    writeln!(
4675        code,
4676        "    // encoder, avoiding per-operation encoder creation overhead."
4677    )?;
4678    writeln!(code)?;
4679
4680    // dispatch_rms_norm
4681    writeln!(
4682        code,
4683        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4684    )?;
4685    writeln!(
4686        code,
4687        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4688    )?;
4689    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4690    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4691    writeln!(
4692        code,
4693        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4694    )?;
4695    writeln!(
4696        code,
4697        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4698    )?;
4699    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4700    writeln!(
4701        code,
4702        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4703    )?;
4704    writeln!(
4705        code,
4706        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4707    )?;
4708    writeln!(
4709        code,
4710        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4711    )?;
4712    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4713    writeln!(
4714        code,
4715        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4716    )?;
4717    writeln!(
4718        code,
4719        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4720    )?;
4721    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4722    writeln!(code, "    }}")?;
4723    writeln!(code)?;
4724
4725    // dispatch_matmul
4726    writeln!(
4727        code,
4728        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4729    )?;
4730    writeln!(
4731        code,
4732        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4733    )?;
4734    writeln!(code, "        let r: u32 = rows as u32;")?;
4735    writeln!(code, "        let c: u32 = cols as u32;")?;
4736    writeln!(
4737        code,
4738        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4739    )?;
4740    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4741    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4742    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4743    writeln!(
4744        code,
4745        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4746    )?;
4747    writeln!(
4748        code,
4749        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4750    )?;
4751    writeln!(
4752        code,
4753        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4754    )?;
4755    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4756    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4757    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4758    writeln!(
4759        code,
4760        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4761    )?;
4762    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4763    writeln!(code, "    }}")?;
4764    writeln!(code)?;
4765
4766    // dispatch_matmul_q8
4767    writeln!(
4768        code,
4769        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4770    )?;
4771    writeln!(
4772        code,
4773        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4774    )?;
4775    writeln!(
4776        code,
4777        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4778    )?;
4779    writeln!(code, "        let r: u32 = rows as u32;")?;
4780    writeln!(code, "        let c: u32 = cols as u32;")?;
4781    writeln!(
4782        code,
4783        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4784    )?;
4785    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4786    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4787    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4788    writeln!(
4789        code,
4790        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4791    )?;
4792    writeln!(
4793        code,
4794        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4795    )?;
4796    writeln!(
4797        code,
4798        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4799    )?;
4800    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4801    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4802    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4803    writeln!(
4804        code,
4805        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4806    )?;
4807    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4808    writeln!(code, "    }}")?;
4809    writeln!(code)?;
4810
4811    // dispatch_matmul_q4
4812    writeln!(
4813        code,
4814        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4815    )?;
4816    writeln!(
4817        code,
4818        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4819    )?;
4820    writeln!(
4821        code,
4822        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4823    )?;
4824    writeln!(code, "        let r: u32 = rows as u32;")?;
4825    writeln!(code, "        let c: u32 = cols as u32;")?;
4826    writeln!(
4827        code,
4828        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4829    )?;
4830    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4831    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4832    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4833    writeln!(
4834        code,
4835        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4836    )?;
4837    writeln!(
4838        code,
4839        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4840    )?;
4841    writeln!(
4842        code,
4843        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4844    )?;
4845    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4846    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4847    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4848    writeln!(
4849        code,
4850        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4851    )?;
4852    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4853    writeln!(code, "    }}")?;
4854    writeln!(code)?;
4855
4856    // dispatch_rope
4857    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4858    writeln!(
4859        code,
4860        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4861    )?;
4862    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4863    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4864    writeln!(code, "        let p: u32 = pos as u32;")?;
4865    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4866    writeln!(
4867        code,
4868        "        let total_pairs = num_heads * (head_dim / 2);"
4869    )?;
4870    writeln!(
4871        code,
4872        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4873    )?;
4874    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4875    writeln!(
4876        code,
4877        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4878    )?;
4879    writeln!(
4880        code,
4881        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4882    )?;
4883    writeln!(
4884        code,
4885        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4886    )?;
4887    writeln!(
4888        code,
4889        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4890    )?;
4891    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4892    writeln!(
4893        code,
4894        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4895    )?;
4896    writeln!(
4897        code,
4898        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4899    )?;
4900    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4901    writeln!(code, "    }}")?;
4902    writeln!(code)?;
4903
4904    // dispatch_rope_offset
4905    writeln!(
4906        code,
4907        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4908    )?;
4909    writeln!(
4910        code,
4911        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4912    )?;
4913    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4914    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4915    writeln!(code, "        let p: u32 = pos as u32;")?;
4916    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4917    writeln!(
4918        code,
4919        "        let total_pairs = num_heads * (head_dim / 2);"
4920    )?;
4921    writeln!(
4922        code,
4923        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4924    )?;
4925    writeln!(
4926        code,
4927        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4928    )?;
4929    writeln!(
4930        code,
4931        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4932    )?;
4933    writeln!(
4934        code,
4935        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4936    )?;
4937    writeln!(
4938        code,
4939        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4940    )?;
4941    writeln!(
4942        code,
4943        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4944    )?;
4945    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4946    writeln!(
4947        code,
4948        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4949    )?;
4950    writeln!(
4951        code,
4952        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4953    )?;
4954    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4955    writeln!(code, "    }}")?;
4956    writeln!(code)?;
4957
4958    // dispatch_attention
4959    writeln!(
4960        code,
4961        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4962    )?;
4963    writeln!(
4964        code,
4965        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4966    )?;
4967    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4968    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4969    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4970    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4971    writeln!(
4972        code,
4973        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4974    )?;
4975    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4976    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4977    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4978    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4979    writeln!(
4980        code,
4981        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4982    )?;
4983    writeln!(
4984        code,
4985        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4986    )?;
4987    writeln!(
4988        code,
4989        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4990    )?;
4991    writeln!(
4992        code,
4993        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4994    )?;
4995    writeln!(
4996        code,
4997        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4998    )?;
4999    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5000    writeln!(
5001        code,
5002        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5003    )?;
5004    writeln!(
5005        code,
5006        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5007    )?;
5008    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5009    writeln!(code, "    }}")?;
5010    writeln!(code)?;
5011
5012    // dispatch_attention_offset
5013    writeln!(
5014        code,
5015        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
5016    )?;
5017    writeln!(
5018        code,
5019        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5020    )?;
5021    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
5022    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5023    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5024    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5025    writeln!(
5026        code,
5027        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
5028    )?;
5029    writeln!(
5030        code,
5031        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5032    )?;
5033    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5034    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5035    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5036    writeln!(
5037        code,
5038        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5039    )?;
5040    writeln!(
5041        code,
5042        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5043    )?;
5044    writeln!(
5045        code,
5046        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5047    )?;
5048    writeln!(
5049        code,
5050        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5051    )?;
5052    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5053    writeln!(
5054        code,
5055        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5056    )?;
5057    writeln!(
5058        code,
5059        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5060    )?;
5061    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5062    writeln!(code, "    }}")?;
5063    writeln!(code)?;
5064
5065    // dispatch_silu_mul
5066    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
5067    writeln!(
5068        code,
5069        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5070    )?;
5071    writeln!(code, "        let count: u32 = n as u32;")?;
5072    writeln!(
5073        code,
5074        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5075    )?;
5076    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
5077    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
5078    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5079    writeln!(
5080        code,
5081        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5082    )?;
5083    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5084    writeln!(
5085        code,
5086        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5087    )?;
5088    writeln!(
5089        code,
5090        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5091    )?;
5092    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5093    writeln!(code, "    }}")?;
5094    writeln!(code)?;
5095
5096    // dispatch_silu_mul_fused
5097    writeln!(
5098        code,
5099        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5100    )?;
5101    writeln!(
5102        code,
5103        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
5104    )?;
5105    writeln!(
5106        code,
5107        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5108    )?;
5109    writeln!(code, "        let count: u32 = n as u32;")?;
5110    writeln!(
5111        code,
5112        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5113    )?;
5114    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5115    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5116    writeln!(
5117        code,
5118        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5119    )?;
5120    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5121    writeln!(
5122        code,
5123        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5124    )?;
5125    writeln!(
5126        code,
5127        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5128    )?;
5129    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5130    writeln!(code, "    }}")?;
5131    writeln!(code)?;
5132
5133    // dispatch_copy (simple src -> dst copy via compute kernel)
5134    writeln!(
5135        code,
5136        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5137    )?;
5138    writeln!(
5139        code,
5140        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5141    )?;
5142    writeln!(code, "        let n: u32 = count as u32;")?;
5143    writeln!(
5144        code,
5145        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5146    )?;
5147    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5148    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5149    writeln!(
5150        code,
5151        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5152    )?;
5153    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5154    writeln!(
5155        code,
5156        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5157    )?;
5158    writeln!(
5159        code,
5160        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5161    )?;
5162    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5163    writeln!(code, "    }}")?;
5164    writeln!(code)?;
5165
5166    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
5167    writeln!(
5168        code,
5169        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5170    )?;
5171    writeln!(
5172        code,
5173        "    /// Used for embedding table lookup (copy a specific row)."
5174    )?;
5175    writeln!(
5176        code,
5177        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5178    )?;
5179    writeln!(code, "        let off: u32 = src_offset as u32;")?;
5180    writeln!(code, "        let n: u32 = count as u32;")?;
5181    writeln!(
5182        code,
5183        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5184    )?;
5185    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5186    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5187    writeln!(
5188        code,
5189        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5190    )?;
5191    writeln!(
5192        code,
5193        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5194    )?;
5195    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5196    writeln!(
5197        code,
5198        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5199    )?;
5200    writeln!(
5201        code,
5202        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5203    )?;
5204    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5205    writeln!(code, "    }}")?;
5206    writeln!(code)?;
5207
5208    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
5209    writeln!(
5210        code,
5211        "    /// Dispatch copy from source at byte offset to destination at float offset."
5212    )?;
5213    writeln!(
5214        code,
5215        "    /// Used for KV cache updates from fused QKV buffer."
5216    )?;
5217    writeln!(
5218        code,
5219        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5220    )?;
5221    writeln!(code, "        let n: u32 = count as u32;")?;
5222    writeln!(
5223        code,
5224        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5225    )?;
5226    writeln!(
5227        code,
5228        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5229    )?;
5230    writeln!(
5231        code,
5232        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5233    )?;
5234    writeln!(
5235        code,
5236        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5237    )?;
5238    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5239    writeln!(
5240        code,
5241        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5242    )?;
5243    writeln!(
5244        code,
5245        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5246    )?;
5247    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5248    writeln!(code, "    }}")?;
5249    writeln!(code)?;
5250
5251    // dispatch_copy_from_offset_f16 (copy src[src_byte_offset..] -> f16 dst[dst_elem_offset..])
5252    writeln!(
5253        code,
5254        "    /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5255    )?;
5256    writeln!(
5257        code,
5258        "    /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5259    )?;
5260    writeln!(
5261        code,
5262        "    fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5263    )?;
5264    writeln!(code, "        let n: u32 = count as u32;")?;
5265    writeln!(
5266        code,
5267        "        enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5268    )?;
5269    writeln!(
5270        code,
5271        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5272    )?;
5273    writeln!(
5274        code,
5275        "        enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5276    )?;
5277    writeln!(
5278        code,
5279        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5280    )?;
5281    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5282    writeln!(
5283        code,
5284        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5285    )?;
5286    writeln!(
5287        code,
5288        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5289    )?;
5290    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5291    writeln!(code, "    }}")?;
5292    writeln!(code)?;
5293
5294    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
5295    writeln!(
5296        code,
5297        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5298    )?;
5299    writeln!(
5300        code,
5301        "    /// Used for KV cache updates (write to a specific position in the cache)."
5302    )?;
5303    writeln!(
5304        code,
5305        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5306    )?;
5307    writeln!(code, "        let n: u32 = count as u32;")?;
5308    writeln!(
5309        code,
5310        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5311    )?;
5312    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5313    writeln!(
5314        code,
5315        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5316    )?;
5317    writeln!(
5318        code,
5319        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5320    )?;
5321    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5322    writeln!(
5323        code,
5324        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5325    )?;
5326    writeln!(
5327        code,
5328        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5329    )?;
5330    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5331    writeln!(code, "    }}")?;
5332    writeln!(code)?;
5333
5334    // dispatch_add_inplace (residual connection, no blit needed)
5335    writeln!(
5336        code,
5337        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5338    )?;
5339    writeln!(
5340        code,
5341        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5342    )?;
5343    writeln!(code, "        let count: u32 = n as u32;")?;
5344    writeln!(
5345        code,
5346        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5347    )?;
5348    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5349    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5350    writeln!(
5351        code,
5352        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5353    )?;
5354    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5355    writeln!(
5356        code,
5357        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5358    )?;
5359    writeln!(
5360        code,
5361        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5362    )?;
5363    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5364    writeln!(code, "    }}")?;
5365    writeln!(code)?;
5366
5367    // ── Batched prefill dispatch helpers ──
5368    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
5369    writeln!(code)?;
5370
5371    // dispatch_copy_embedding_batch
5372    writeln!(
5373        code,
5374        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
5375    )?;
5376    writeln!(
5377        code,
5378        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5379    )?;
5380    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
5381    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5382    writeln!(
5383        code,
5384        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5385    )?;
5386    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5387    writeln!(
5388        code,
5389        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5390    )?;
5391    writeln!(
5392        code,
5393        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5394    )?;
5395    writeln!(
5396        code,
5397        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5398    )?;
5399    writeln!(
5400        code,
5401        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5402    )?;
5403    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
5404    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5405    writeln!(
5406        code,
5407        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5408    )?;
5409    writeln!(
5410        code,
5411        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5412    )?;
5413    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5414    writeln!(code, "    }}")?;
5415    writeln!(code)?;
5416
5417    // dispatch_rms_norm_batch
5418    writeln!(
5419        code,
5420        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
5421    )?;
5422    writeln!(
5423        code,
5424        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5425    )?;
5426    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
5427    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
5428    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5429    writeln!(
5430        code,
5431        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5432    )?;
5433    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
5434    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
5435    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5436    writeln!(
5437        code,
5438        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5439    )?;
5440    writeln!(
5441        code,
5442        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5443    )?;
5444    writeln!(
5445        code,
5446        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5447    )?;
5448    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5449    writeln!(
5450        code,
5451        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5452    )?;
5453    writeln!(
5454        code,
5455        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5456    )?;
5457    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5458    writeln!(code, "    }}")?;
5459    writeln!(code)?;
5460
5461    // dispatch_matmul_batch (f32)
5462    writeln!(
5463        code,
5464        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5465    )?;
5466    writeln!(
5467        code,
5468        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5469    )?;
5470    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5471    writeln!(code, "        let r: u32 = rows as u32;")?;
5472    writeln!(code, "        let c: u32 = cols as u32;")?;
5473    writeln!(
5474        code,
5475        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5476    )?;
5477    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5478    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5479    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5480    writeln!(
5481        code,
5482        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5483    )?;
5484    writeln!(
5485        code,
5486        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5487    )?;
5488    writeln!(
5489        code,
5490        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5491    )?;
5492    writeln!(
5493        code,
5494        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
5495    )?;
5496    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5497    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5498    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5499    writeln!(
5500        code,
5501        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5502    )?;
5503    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5504    writeln!(code, "    }}")?;
5505    writeln!(code)?;
5506
5507    // dispatch_matmul_q8_batch
5508    writeln!(
5509        code,
5510        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5511    )?;
5512    writeln!(code, "    ///")?;
5513    writeln!(
5514        code,
5515        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5516    )?;
5517    writeln!(
5518        code,
5519        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5520    )?;
5521    writeln!(
5522        code,
5523        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5524    )?;
5525    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5526    writeln!(code, "        let r: u32 = rows as u32;")?;
5527    writeln!(code, "        let c: u32 = cols as u32;")?;
5528    writeln!(
5529        code,
5530        "        // Tile sizes must match the Metal shader constants."
5531    )?;
5532    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5533    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5534    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5535    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5536    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5537    writeln!(
5538        code,
5539        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5540    )?;
5541    writeln!(
5542        code,
5543        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5544    )?;
5545    writeln!(
5546        code,
5547        "        // dispatch count and reuses each weight load across 32 tokens."
5548    )?;
5549    writeln!(
5550        code,
5551        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5552    )?;
5553    writeln!(
5554        code,
5555        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5556    )?;
5557    writeln!(
5558        code,
5559        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5560    )?;
5561    writeln!(
5562        code,
5563        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5564    )?;
5565    writeln!(
5566        code,
5567        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5568    )?;
5569    writeln!(
5570        code,
5571        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5572    )?;
5573    writeln!(
5574        code,
5575        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5576    )?;
5577    writeln!(
5578        code,
5579        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5580    )?;
5581    writeln!(
5582        code,
5583        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5584    )?;
5585    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5586    writeln!(code, "            let pipe = if use_h4 {{")?;
5587    writeln!(code, "                if num_tokens >= 256 {{")?;
5588    writeln!(
5589        code,
5590        "                    &self.matmul_q8_mma32_hh4_pipeline"
5591    )?;
5592    writeln!(code, "                }} else {{")?;
5593    writeln!(
5594        code,
5595        "                    &self.matmul_q8_mma32_h4_pipeline"
5596    )?;
5597    writeln!(code, "                }}")?;
5598    writeln!(code, "            }} else {{")?;
5599    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5600    writeln!(code, "            }};")?;
5601    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
5602    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5603    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5604    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5605    writeln!(
5606        code,
5607        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5608    )?;
5609    writeln!(
5610        code,
5611        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5612    )?;
5613    writeln!(
5614        code,
5615        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5616    )?;
5617    writeln!(code, "            let row_tgs = rows / MMA32_ROW_TILE;")?;
5618    writeln!(
5619        code,
5620        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5621    )?;
5622    writeln!(
5623        code,
5624        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5625    )?;
5626    writeln!(
5627        code,
5628        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5629    )?;
5630    writeln!(
5631        code,
5632        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5633    )?;
5634    writeln!(
5635        code,
5636        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5637    )?;
5638    writeln!(
5639        code,
5640        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5641    )?;
5642    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5643    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5644    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5645    writeln!(
5646        code,
5647        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5648    )?;
5649    writeln!(
5650        code,
5651        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5652    )?;
5653    writeln!(
5654        code,
5655        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5656    )?;
5657    writeln!(code, "            let row_tgs = rows / MMA_ROW_TILE;")?;
5658    writeln!(
5659        code,
5660        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5661    )?;
5662    writeln!(
5663        code,
5664        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5665    )?;
5666    writeln!(
5667        code,
5668        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5669    )?;
5670    writeln!(
5671        code,
5672        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5673    )?;
5674    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5675    writeln!(
5676        code,
5677        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5678    )?;
5679    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5680    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5681    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5682    writeln!(
5683        code,
5684        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5685    )?;
5686    writeln!(
5687        code,
5688        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5689    )?;
5690    writeln!(
5691        code,
5692        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5693    )?;
5694    writeln!(
5695        code,
5696        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5697    )?;
5698    writeln!(
5699        code,
5700        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5701    )?;
5702    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5703    writeln!(
5704        code,
5705        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5706    )?;
5707    writeln!(
5708        code,
5709        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5710    )?;
5711    writeln!(code, "        }} else {{")?;
5712    writeln!(
5713        code,
5714        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5715    )?;
5716    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5717    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5718    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5719    writeln!(
5720        code,
5721        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5722    )?;
5723    writeln!(
5724        code,
5725        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5726    )?;
5727    writeln!(
5728        code,
5729        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5730    )?;
5731    writeln!(
5732        code,
5733        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5734    )?;
5735    writeln!(
5736        code,
5737        "            let num_tg = (row_tgs * num_tokens) as u64;"
5738    )?;
5739    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5740    writeln!(
5741        code,
5742        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5743    )?;
5744    writeln!(
5745        code,
5746        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5747    )?;
5748    writeln!(code, "        }}")?;
5749    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5750    writeln!(code, "    }}")?;
5751    writeln!(code)?;
5752
5753    // dispatch_matmul_q4_batch
5754    writeln!(
5755        code,
5756        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5757    )?;
5758    writeln!(
5759        code,
5760        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5761    )?;
5762    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5763    writeln!(code, "        let r: u32 = rows as u32;")?;
5764    writeln!(code, "        let c: u32 = cols as u32;")?;
5765    writeln!(
5766        code,
5767        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5768    )?;
5769    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5770    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5771    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5772    writeln!(
5773        code,
5774        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5775    )?;
5776    writeln!(
5777        code,
5778        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5779    )?;
5780    writeln!(
5781        code,
5782        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5783    )?;
5784    writeln!(
5785        code,
5786        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5787    )?;
5788    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5789    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5790    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5791    writeln!(
5792        code,
5793        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5794    )?;
5795    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5796    writeln!(code, "    }}")?;
5797    writeln!(code)?;
5798
5799    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5800    if config.qkv_bias {
5801        writeln!(
5802            code,
5803            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5804        )?;
5805        writeln!(
5806            code,
5807            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5808        )?;
5809        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5810        writeln!(code, "        let r: u32 = rows as u32;")?;
5811        writeln!(
5812            code,
5813            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5814        )?;
5815        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5816        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5817        writeln!(
5818            code,
5819            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5820        )?;
5821        writeln!(
5822            code,
5823            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5824        )?;
5825        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5826        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5827        writeln!(
5828            code,
5829            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5830        )?;
5831        writeln!(
5832            code,
5833            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5834        )?;
5835        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5836        writeln!(code, "    }}")?;
5837        writeln!(code)?;
5838    }
5839
5840    // dispatch_rope_batch
5841    writeln!(
5842        code,
5843        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5844    )?;
5845    writeln!(
5846        code,
5847        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5848    )?;
5849    writeln!(
5850        code,
5851        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5852    )?;
5853    writeln!(
5854        code,
5855        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5856    )?;
5857    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5858    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5859    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5860    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5861    writeln!(
5862        code,
5863        "        let pairs_per_token = num_heads * (head_dim / 2);"
5864    )?;
5865    writeln!(
5866        code,
5867        "        let total_pairs = num_tokens * pairs_per_token;"
5868    )?;
5869    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data.
5870    // Since our batch_qkv_buf is [M, qkv_rows] and Q/K are at offsets within each row,
5871    // we need to pass the buffer at the right byte offset for each token's data.
5872    // Actually, the rope_batch kernel accesses data[token * (num_heads * head_dim) + ...],
5873    // but our layout is data[token * row_stride + data_float_offset + ...].
5874    // We need the kernel to know the row_stride. Let me adjust the kernel approach:
5875    // Since Q and K are contiguous within each token's qkv_rows, and the batch buffer
5876    // is [M, qkv_rows], we can pass the buffer at offset (data_float_offset * 4) and
5877    // use a stride parameter. But the rope_batch kernel as written expects [M, num_heads*head_dim].
5878    //
5879    // Simplest approach: use the single-token rope kernel for each token in a loop.
5880    // This is still efficient because we're dispatching all within the same command encoder.
5881    writeln!(
5882        code,
5883        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5884    )?;
5885    writeln!(code, "        for t in 0..num_tokens {{")?;
5886    writeln!(
5887        code,
5888        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5889    )?;
5890    writeln!(
5891        code,
5892        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5893    )?;
5894    writeln!(
5895        code,
5896        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5897    )?;
5898    writeln!(
5899        code,
5900        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5901    )?;
5902    writeln!(
5903        code,
5904        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5905    )?;
5906    writeln!(
5907        code,
5908        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5909    )?;
5910    writeln!(
5911        code,
5912        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5913    )?;
5914    writeln!(
5915        code,
5916        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5917    )?;
5918    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5919    writeln!(
5920        code,
5921        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5922    )?;
5923    writeln!(
5924        code,
5925        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5926    )?;
5927    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5928    writeln!(code, "        }}")?;
5929    writeln!(code, "    }}")?;
5930    writeln!(code)?;
5931
5932    // dispatch_silu_mul_fused_batch
5933    writeln!(
5934        code,
5935        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5936    )?;
5937    writeln!(
5938        code,
5939        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5940    )?;
5941    writeln!(code, "        let count: u32 = n as u32;")?;
5942    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5943    writeln!(
5944        code,
5945        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5946    )?;
5947    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5948    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5949    writeln!(
5950        code,
5951        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5952    )?;
5953    writeln!(
5954        code,
5955        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5956    )?;
5957    writeln!(code, "        let total = num_tokens * n;")?;
5958    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5959    writeln!(
5960        code,
5961        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5962    )?;
5963    writeln!(
5964        code,
5965        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5966    )?;
5967    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5968    writeln!(code, "    }}")?;
5969    writeln!(code)?;
5970
5971    // dispatch_add_inplace_batch_n (add n elements in-place)
5972    writeln!(
5973        code,
5974        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5975    )?;
5976    writeln!(
5977        code,
5978        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5979    )?;
5980    writeln!(code, "        let count: u32 = total_n as u32;")?;
5981    writeln!(
5982        code,
5983        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5984    )?;
5985    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5986    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5987    writeln!(
5988        code,
5989        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5990    )?;
5991    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5992    writeln!(
5993        code,
5994        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5995    )?;
5996    writeln!(
5997        code,
5998        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5999    )?;
6000    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6001    writeln!(code, "    }}")?;
6002    writeln!(code)?;
6003
6004    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
6005    writeln!(
6006        code,
6007        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
6008    )?;
6009    writeln!(
6010        code,
6011        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
6012    )?;
6013    writeln!(code, "        let n: u32 = count as u32;")?;
6014    writeln!(
6015        code,
6016        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6017    )?;
6018    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6019    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6020    writeln!(
6021        code,
6022        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6023    )?;
6024    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6025    writeln!(
6026        code,
6027        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6028    )?;
6029    writeln!(
6030        code,
6031        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6032    )?;
6033    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6034    writeln!(code, "    }}")?;
6035    writeln!(code)?;
6036
6037    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
6038    writeln!(
6039        code,
6040        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6041    )?;
6042    writeln!(
6043        code,
6044        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6045    )?;
6046    writeln!(code, "        let n: u32 = count as u32;")?;
6047    writeln!(
6048        code,
6049        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6050    )?;
6051    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6052    writeln!(
6053        code,
6054        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6055    )?;
6056    writeln!(
6057        code,
6058        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6059    )?;
6060    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6061    writeln!(
6062        code,
6063        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6064    )?;
6065    writeln!(
6066        code,
6067        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6068    )?;
6069    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6070    writeln!(code, "    }}")?;
6071    writeln!(code)?;
6072
6073    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
6074    writeln!(
6075        code,
6076        "    /// Copy from src at byte offset to dst at float offset."
6077    )?;
6078    writeln!(
6079        code,
6080        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6081    )?;
6082    writeln!(code, "        let n: u32 = count as u32;")?;
6083    writeln!(
6084        code,
6085        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6086    )?;
6087    writeln!(
6088        code,
6089        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6090    )?;
6091    writeln!(
6092        code,
6093        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6094    )?;
6095    writeln!(
6096        code,
6097        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6098    )?;
6099    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6100    writeln!(
6101        code,
6102        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6103    )?;
6104    writeln!(
6105        code,
6106        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6107    )?;
6108    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6109    writeln!(code, "    }}")?;
6110    writeln!(code)?;
6111
6112    // dispatch_copy_kv_batch
6113    writeln!(
6114        code,
6115        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6116    )?;
6117    writeln!(
6118        code,
6119        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6120    )?;
6121    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6122    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6123    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6124    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6125    writeln!(code, "        let so: u32 = src_offset as u32;")?;
6126    writeln!(
6127        code,
6128        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6129    )?;
6130    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6131    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6132    writeln!(
6133        code,
6134        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6135    )?;
6136    writeln!(
6137        code,
6138        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6139    )?;
6140    writeln!(
6141        code,
6142        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6143    )?;
6144    writeln!(
6145        code,
6146        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6147    )?;
6148    writeln!(
6149        code,
6150        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6151    )?;
6152    writeln!(code, "        let total = num_tokens * kv_dim;")?;
6153    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6154    writeln!(
6155        code,
6156        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6157    )?;
6158    writeln!(
6159        code,
6160        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6161    )?;
6162    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6163    writeln!(code, "    }}")?;
6164    writeln!(code)?;
6165
6166    // dispatch_attention_batch
6167    writeln!(
6168        code,
6169        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
6170    )?;
6171    writeln!(
6172        code,
6173        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6174    )?;
6175    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6176    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6177    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
6178    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6179    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6180    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
6181    // Attention kernel selection:
6182    //   * Legacy `attention_batch` materializes scores[4096] in threadgroup memory
6183    //     and uses scalar simdgroup reductions.  Fast at short seq_len, no MMA.
6184    //   * `attention_flash_batch` streams K/V with online softmax; no seq cap,
6185    //     scalar math, ~7-14 % slower than legacy at long contexts (no MMA).
6186    //   * `attention_mma_flash_batch` adds hardware simdgroup_matrix<half, 8, 8>
6187    //     MMA for both Q·K^T and P·V, processing Q_BLOCK=8 tokens per TG.
6188    //     Default path when HEAD_DIM ≤ 128 and num_tokens ≥ 8 (verified on
6189    //     Llama, Qwen2.5, Phi-3).  Set FORGE_MMA_ATTN=0 to force legacy.
6190    writeln!(code, "        let max_seq = base_pos + num_tokens;")?;
6191    writeln!(code, "        let _ = max_seq;")?;
6192    writeln!(
6193        code,
6194        "        let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6195    )?;
6196    writeln!(code, "            .map(|v| v == \"0\").unwrap_or(false);")?;
6197    writeln!(
6198        code,
6199        "        let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6200    )?;
6201    writeln!(code, "        if use_mma_flash {{")?;
6202    writeln!(
6203        code,
6204        "            let pipe = &self.attention_mma_flash_batch_pipeline;"
6205    )?;
6206    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
6207    writeln!(code, "            enc.set_buffer(0, Some(q_buf), 0);")?;
6208    writeln!(code, "            enc.set_buffer(1, Some(k_cache), 0);")?;
6209    writeln!(code, "            enc.set_buffer(2, Some(v_cache), 0);")?;
6210    writeln!(code, "            enc.set_buffer(3, Some(output), 0);")?;
6211    writeln!(
6212        code,
6213        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6214    )?;
6215    writeln!(
6216        code,
6217        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6218    )?;
6219    writeln!(
6220        code,
6221        "            enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6222    )?;
6223    writeln!(
6224        code,
6225        "            enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6226    )?;
6227    writeln!(
6228        code,
6229        "            enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6230    )?;
6231    writeln!(
6232        code,
6233        "            enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6234    )?;
6235    writeln!(
6236        code,
6237        "            // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6238    )?;
6239    writeln!(code, "            let tg_size = MTLSize::new(128, 1, 1);")?;
6240    writeln!(
6241        code,
6242        "            let q_blocks = ((num_tokens + 7) / 8) as u64;"
6243    )?;
6244    writeln!(
6245        code,
6246        "            let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6247    )?;
6248    writeln!(
6249        code,
6250        "            enc.dispatch_thread_groups(grid_size, tg_size);"
6251    )?;
6252    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6253    writeln!(code, "            return;")?;
6254    writeln!(code, "        }}")?;
6255    writeln!(code, "        let pipe = &self.attention_batch_pipeline;")?;
6256    writeln!(code, "        enc.set_compute_pipeline_state(pipe);")?;
6257    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
6258    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
6259    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
6260    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
6261    writeln!(
6262        code,
6263        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6264    )?;
6265    writeln!(
6266        code,
6267        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6268    )?;
6269    writeln!(
6270        code,
6271        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6272    )?;
6273    writeln!(
6274        code,
6275        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6276    )?;
6277    writeln!(
6278        code,
6279        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6280    )?;
6281    writeln!(
6282        code,
6283        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6284    )?;
6285    writeln!(
6286        code,
6287        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6288    )?;
6289    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6290    writeln!(
6291        code,
6292        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6293    )?;
6294    writeln!(
6295        code,
6296        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6297    )?;
6298    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6299    writeln!(code, "    }}")?;
6300    writeln!(code)?;
6301
6302    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
6303    writeln!(
6304        code,
6305        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6306    )?;
6307    writeln!(
6308        code,
6309        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6310    )?;
6311    writeln!(
6312        code,
6313        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6314    )?;
6315    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6316    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6317    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
6318    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6319    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6320    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
6321    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
6322    writeln!(
6323        code,
6324        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6325    )?;
6326    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
6327    writeln!(
6328        code,
6329        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6330    )?;
6331    writeln!(
6332        code,
6333        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6334    )?;
6335    writeln!(
6336        code,
6337        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6338    )?;
6339    writeln!(
6340        code,
6341        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6342    )?;
6343    writeln!(
6344        code,
6345        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6346    )?;
6347    writeln!(
6348        code,
6349        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6350    )?;
6351    writeln!(
6352        code,
6353        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6354    )?;
6355    writeln!(
6356        code,
6357        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6358    )?;
6359    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6360    writeln!(
6361        code,
6362        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6363    )?;
6364    writeln!(
6365        code,
6366        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6367    )?;
6368    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6369    writeln!(code, "    }}")?;
6370    writeln!(code)?;
6371
6372    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
6373    writeln!(
6374        code,
6375        "    /// Dispatch fused K+V cache copy in one kernel launch."
6376    )?;
6377    writeln!(
6378        code,
6379        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6380    )?;
6381    writeln!(
6382        code,
6383        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6384    )?;
6385    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6386    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6387    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6388    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6389    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6390    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6391    writeln!(
6392        code,
6393        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6394    )?;
6395    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6396    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6397    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6398    writeln!(
6399        code,
6400        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6401    )?;
6402    writeln!(
6403        code,
6404        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6405    )?;
6406    writeln!(
6407        code,
6408        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6409    )?;
6410    writeln!(
6411        code,
6412        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6413    )?;
6414    writeln!(
6415        code,
6416        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6417    )?;
6418    writeln!(
6419        code,
6420        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6421    )?;
6422    writeln!(
6423        code,
6424        "        let total = num_tokens * kv_dim * 2;  // K + V"
6425    )?;
6426    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6427    writeln!(
6428        code,
6429        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6430    )?;
6431    writeln!(
6432        code,
6433        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6434    )?;
6435    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6436    writeln!(code, "    }}")?;
6437
6438    writeln!(code, "}}")?;
6439    writeln!(code)?;
6440
6441    Ok(())
6442}
6443
6444fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6445    writeln!(
6446        code,
6447        "// ── Helper functions ──────────────────────────────────"
6448    )?;
6449    writeln!(code)?;
6450    writeln!(
6451        code,
6452        "/// Create a compute pipeline from a named function in the Metal library."
6453    )?;
6454    writeln!(
6455        code,
6456        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6457    )?;
6458    writeln!(
6459        code,
6460        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6461    )?;
6462    writeln!(
6463        code,
6464        "    device.new_compute_pipeline_state_with_function(&func)"
6465    )?;
6466    writeln!(
6467        code,
6468        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6469    )?;
6470    writeln!(code, "}}")?;
6471    writeln!(code)?;
6472
6473    Ok(())
6474}
6475
6476// ---------------------------------------------------------------------------
6477// main.rs generation
6478// ---------------------------------------------------------------------------
6479
6480fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6481    let _sanitized = sanitize_name(model_name);
6482    let _vocab = config.vocab_size;
6483
6484    let mut code = String::with_capacity(16 * 1024);
6485    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6486    writeln!(
6487        code,
6488        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6489    )?;
6490    writeln!(code)?;
6491    writeln!(code, "mod model;")?;
6492    writeln!(code)?;
6493    writeln!(code, "use std::io::Write;")?;
6494    writeln!(code, "use std::time::Instant;")?;
6495    writeln!(code, "use serde::Deserialize;")?;
6496    writeln!(code)?;
6497
6498    // -- main function --
6499    writeln!(code, "fn main() {{")?;
6500    writeln!(
6501        code,
6502        "    let args: Vec<String> = std::env::args().collect();"
6503    )?;
6504    writeln!(code)?;
6505    writeln!(
6506        code,
6507        "    // Detect --serve mode (only requires weights + tokenizer)"
6508    )?;
6509    writeln!(
6510        code,
6511        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6512    )?;
6513    writeln!(code)?;
6514    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6515    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6516    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6517    writeln!(code, "        std::process::exit(1);")?;
6518    writeln!(code, "    }}")?;
6519    writeln!(code)?;
6520    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6521    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6522    writeln!(code, "        std::process::exit(1);")?;
6523    writeln!(code, "    }}")?;
6524    writeln!(code)?;
6525    writeln!(code, "    let weights_path = &args[1];")?;
6526    writeln!(code, "    let tokenizer_path = &args[2];")?;
6527    writeln!(code)?;
6528    writeln!(code, "    // Parse optional flags")?;
6529    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6530    writeln!(code, "    let mut port: u16 = 8080;")?;
6531    writeln!(
6532        code,
6533        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6534    )?;
6535    writeln!(
6536        code,
6537        "    let profile = args.iter().any(|a| a == \"--profile\");"
6538    )?;
6539    writeln!(code, "    let mut i = 3;")?;
6540    writeln!(code, "    while i < args.len() {{")?;
6541    writeln!(
6542        code,
6543        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6544    )?;
6545    writeln!(
6546        code,
6547        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6548    )?;
6549    writeln!(code, "            i += 2;")?;
6550    writeln!(
6551        code,
6552        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6553    )?;
6554    writeln!(
6555        code,
6556        "            port = args[i + 1].parse().unwrap_or(8080);"
6557    )?;
6558    writeln!(code, "            i += 2;")?;
6559    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6560    writeln!(code, "            i += 1;")?;
6561    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6562    writeln!(code, "            i += 1;")?;
6563    writeln!(code, "        }} else {{")?;
6564    writeln!(code, "            i += 1;")?;
6565    writeln!(code, "        }}")?;
6566    writeln!(code, "    }}")?;
6567    writeln!(code)?;
6568
6569    // -- load model (shared by both modes) --
6570    writeln!(
6571        code,
6572        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6573    )?;
6574    writeln!(
6575        code,
6576        "    let weights_file = std::fs::File::open(weights_path)"
6577    )?;
6578    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6579    writeln!(
6580        code,
6581        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6582    )?;
6583    writeln!(code)?;
6584    writeln!(code, "    // Load tokenizer")?;
6585    writeln!(
6586        code,
6587        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6588    )?;
6589    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6590    writeln!(code)?;
6591    writeln!(code, "    // Create Metal model")?;
6592    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6593    writeln!(
6594        code,
6595        "    let mut model = model::MetalModel::new(&weights_mmap);"
6596    )?;
6597    writeln!(code)?;
6598
6599    // -- branch: serve vs CLI --
6600    writeln!(code, "    if serve_mode {{")?;
6601    writeln!(code, "        serve(model, tokenizer, port);")?;
6602    writeln!(code, "    }} else {{")?;
6603    writeln!(code, "        let prompt = &args[3];")?;
6604    writeln!(
6605        code,
6606        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6607    )?;
6608    writeln!(code, "    }}")?;
6609    writeln!(code, "}}")?;
6610    writeln!(code)?;
6611
6612    // -- cli_mode function --
6613    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6614    writeln!(code, "    // Tokenize prompt")?;
6615    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6616    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6617    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6618    writeln!(code)?;
6619    writeln!(
6620        code,
6621        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6622    )?;
6623    writeln!(
6624        code,
6625        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6626    )?;
6627    writeln!(
6628        code,
6629        "    // The last token uses synchronous forward() to get logits."
6630    )?;
6631    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6632    writeln!(code, "    let prefill_start = Instant::now();")?;
6633    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6634    writeln!(
6635        code,
6636        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6637    )?;
6638    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6639    writeln!(code, "    }} else {{")?;
6640    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6641    writeln!(code, "    }};")?;
6642    writeln!(
6643        code,
6644        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6645    )?;
6646    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6647    writeln!(
6648        code,
6649        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6650    )?;
6651    writeln!(
6652        code,
6653        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6654    )?;
6655    writeln!(code)?;
6656    writeln!(code, "    // Generate tokens")?;
6657    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6658    writeln!(code, "    let gen_start = Instant::now();")?;
6659    writeln!(code, "    let mut generated_count: usize = 0;")?;
6660    writeln!(code)?;
6661    writeln!(code, "    for _ in 0..max_tokens {{")?;
6662    writeln!(
6663        code,
6664        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6665    )?;
6666    writeln!(code, "            if !quiet {{")?;
6667    writeln!(code, "                print!(\"{{}}\", text);")?;
6668    writeln!(code, "                std::io::stdout().flush().ok();")?;
6669    writeln!(code, "            }}")?;
6670    writeln!(code, "        }}")?;
6671    writeln!(code, "        generated_count += 1;")?;
6672    writeln!(code)?;
6673    writeln!(
6674        code,
6675        "        // Use profiling forward for first token when --profile is set"
6676    )?;
6677    writeln!(
6678        code,
6679        "        let logits = if profile && generated_count == 1 {{"
6680    )?;
6681    writeln!(code, "            model.forward_profile(next_token)")?;
6682    writeln!(code, "        }} else {{")?;
6683    writeln!(code, "            model.forward(next_token)")?;
6684    writeln!(code, "        }};")?;
6685    writeln!(code, "        next_token = argmax(&logits);")?;
6686    writeln!(code)?;
6687    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6688    writeln!(code, "        if next_token == 2 {{")?;
6689    writeln!(code, "            break;")?;
6690    writeln!(code, "        }}")?;
6691    writeln!(code)?;
6692    writeln!(
6693        code,
6694        "        // Yield between tokens to reduce sustained GPU thermal load."
6695    )?;
6696    writeln!(
6697        code,
6698        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6699    )?;
6700    writeln!(
6701        code,
6702        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6703    )?;
6704    writeln!(
6705        code,
6706        "        // briefly, providing a micro-break that helps sustain peak throughput."
6707    )?;
6708    writeln!(code, "        std::thread::yield_now();")?;
6709    writeln!(code, "    }}")?;
6710    writeln!(code, "    if !quiet {{")?;
6711    writeln!(code, "        println!();")?;
6712    writeln!(code, "    }}")?;
6713    writeln!(
6714        code,
6715        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6716    )?;
6717    writeln!(
6718        code,
6719        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6720    )?;
6721    writeln!(
6722        code,
6723        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6724    )?;
6725    writeln!(code, "}}")?;
6726    writeln!(code)?;
6727
6728    // -- argmax helper --
6729    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6730    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6731    writeln!(code, "    logits.iter()")?;
6732    writeln!(code, "        .enumerate()")?;
6733    writeln!(
6734        code,
6735        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6736    )?;
6737    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6738    writeln!(code, "        .unwrap_or(0)")?;
6739    writeln!(code, "}}")?;
6740    writeln!(code)?;
6741
6742    // -- Request/Response types for OpenAI API --
6743    writeln!(
6744        code,
6745        "// -----------------------------------------------------------------------"
6746    )?;
6747    writeln!(code, "// OpenAI-compatible API server")?;
6748    writeln!(
6749        code,
6750        "// -----------------------------------------------------------------------"
6751    )?;
6752    writeln!(code)?;
6753    writeln!(code, "#[derive(Deserialize)]")?;
6754    writeln!(code, "struct ChatRequest {{")?;
6755    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6756    writeln!(code, "    #[serde(default)]")?;
6757    writeln!(code, "    stream: Option<bool>,")?;
6758    writeln!(code, "    #[serde(default)]")?;
6759    writeln!(code, "    max_tokens: Option<usize>,")?;
6760    writeln!(code, "    #[serde(default)]")?;
6761    writeln!(code, "    temperature: Option<f32>,")?;
6762    writeln!(code, "    #[serde(default)]")?;
6763    writeln!(code, "    model: Option<String>,")?;
6764    writeln!(code, "}}")?;
6765    writeln!(code)?;
6766    writeln!(code, "#[derive(Deserialize)]")?;
6767    writeln!(code, "struct ChatMessage {{")?;
6768    writeln!(code, "    role: String,")?;
6769    writeln!(code, "    content: String,")?;
6770    writeln!(code, "}}")?;
6771    writeln!(code)?;
6772
6773    // -- format_chat_messages --
6774    writeln!(
6775        code,
6776        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6777    )?;
6778    writeln!(code, "    let mut prompt = String::new();")?;
6779    writeln!(code, "    for msg in messages {{")?;
6780    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6781    writeln!(code, "    }}")?;
6782    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6783    writeln!(code, "    prompt")?;
6784    writeln!(code, "}}")?;
6785    writeln!(code)?;
6786
6787    // -- prefill helper --
6788    writeln!(
6789        code,
6790        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6791    )?;
6792    writeln!(code, "    let len = tokens.len();")?;
6793    writeln!(code, "    if len > 1 {{")?;
6794    writeln!(
6795        code,
6796        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6797    )?;
6798    writeln!(code, "    }}")?;
6799    writeln!(code, "    model.forward(tokens[len - 1])")?;
6800    writeln!(code, "}}")?;
6801    writeln!(code)?;
6802
6803    // -- serve function --
6804    writeln!(
6805        code,
6806        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6807    )?;
6808    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6809    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6810    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6811    writeln!(
6812        code,
6813        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6814    )?;
6815    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6816    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6817    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6818    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6819    writeln!(code)?;
6820    writeln!(code, "    for request in server.incoming_requests() {{")?;
6821    writeln!(code, "        let method = request.method().to_string();")?;
6822    writeln!(code, "        let url = request.url().to_string();")?;
6823    writeln!(code)?;
6824    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6825
6826    // -- POST /v1/chat/completions --
6827    writeln!(
6828        code,
6829        "            (\"POST\", \"/v1/chat/completions\") => {{"
6830    )?;
6831    writeln!(
6832        code,
6833        "                handle_chat_completion(&mut model, &tokenizer, request);"
6834    )?;
6835    writeln!(code, "            }}")?;
6836
6837    // -- GET /v1/models --
6838    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6839    writeln!(code, "                let body = serde_json::json!({{")?;
6840    writeln!(code, "                    \"object\": \"list\",")?;
6841    writeln!(code, "                    \"data\": [{{")?;
6842    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6843    writeln!(code, "                        \"object\": \"model\",")?;
6844    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6845    writeln!(code, "                    }}]")?;
6846    writeln!(code, "                }});")?;
6847    writeln!(
6848        code,
6849        "                let resp = tiny_http::Response::from_string(body.to_string())"
6850    )?;
6851    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6852    writeln!(code, "                request.respond(resp).ok();")?;
6853    writeln!(code, "            }}")?;
6854
6855    // -- GET /health --
6856    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6857    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6858    writeln!(code, "                request.respond(resp).ok();")?;
6859    writeln!(code, "            }}")?;
6860
6861    // -- 404 --
6862    writeln!(code, "            _ => {{")?;
6863    writeln!(
6864        code,
6865        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6866    )?;
6867    writeln!(code, "                    .with_status_code(404);")?;
6868    writeln!(code, "                request.respond(resp).ok();")?;
6869    writeln!(code, "            }}")?;
6870    writeln!(code, "        }}")?;
6871    writeln!(code, "    }}")?;
6872    writeln!(code, "}}")?;
6873    writeln!(code)?;
6874
6875    // -- handle_chat_completion --
6876    writeln!(code, "fn handle_chat_completion(")?;
6877    writeln!(code, "    model: &mut model::MetalModel,")?;
6878    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6879    writeln!(code, "    mut request: tiny_http::Request,")?;
6880    writeln!(code, ") {{")?;
6881    writeln!(code, "    // Read request body")?;
6882    writeln!(code, "    let mut body = String::new();")?;
6883    writeln!(
6884        code,
6885        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6886    )?;
6887    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6888    writeln!(code, "            .with_status_code(400);")?;
6889    writeln!(code, "        request.respond(resp).ok();")?;
6890    writeln!(code, "        return;")?;
6891    writeln!(code, "    }}")?;
6892    writeln!(code)?;
6893    writeln!(code, "    // Parse JSON")?;
6894    writeln!(
6895        code,
6896        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6897    )?;
6898    writeln!(code, "        Ok(r) => r,")?;
6899    writeln!(code, "        Err(e) => {{")?;
6900    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6901    writeln!(code, "                .with_status_code(400);")?;
6902    writeln!(code, "            request.respond(resp).ok();")?;
6903    writeln!(code, "            return;")?;
6904    writeln!(code, "        }}")?;
6905    writeln!(code, "    }};")?;
6906    writeln!(code)?;
6907    writeln!(
6908        code,
6909        "    let prompt = format_chat_messages(&req.messages);"
6910    )?;
6911    writeln!(
6912        code,
6913        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6914    )?;
6915    writeln!(code, "        Ok(e) => e,")?;
6916    writeln!(code, "        Err(e) => {{")?;
6917    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6918    writeln!(code, "                .with_status_code(500);")?;
6919    writeln!(code, "            request.respond(resp).ok();")?;
6920    writeln!(code, "            return;")?;
6921    writeln!(code, "        }}")?;
6922    writeln!(code, "    }};")?;
6923    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6924    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6925    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6926    writeln!(
6927        code,
6928        "    let _temperature = req.temperature.unwrap_or(1.0);"
6929    )?;
6930    writeln!(code)?;
6931
6932    // -- Reset KV cache for each request --
6933    writeln!(code, "    model.reset();")?;
6934    writeln!(code)?;
6935
6936    // -- Prefill with timing --
6937    writeln!(code, "    let prefill_start = Instant::now();")?;
6938    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6939    writeln!(
6940        code,
6941        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6942    )?;
6943    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6944    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6945    writeln!(code)?;
6946
6947    writeln!(code, "    if stream {{")?;
6948
6949    // -- SSE streaming response --
6950    writeln!(
6951        code,
6952        "        // SSE streaming: generate tokens and build SSE body"
6953    )?;
6954    writeln!(code, "        let gen_start = Instant::now();")?;
6955    writeln!(code, "        let mut generated_count: usize = 0;")?;
6956    writeln!(code, "        let mut sse_body = String::new();")?;
6957    writeln!(code, "        for _ in 0..max_tokens {{")?;
6958    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6959    writeln!(
6960        code,
6961        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6962    )?;
6963    writeln!(
6964        code,
6965        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6966    )?;
6967    writeln!(
6968        code,
6969        "                // escaped includes surrounding quotes, strip them"
6970    )?;
6971    writeln!(
6972        code,
6973        "                let inner = &escaped[1..escaped.len()-1];"
6974    )?;
6975    writeln!(code, "                sse_body.push_str(&format!(")?;
6976    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6977    writeln!(code, "                    inner")?;
6978    writeln!(code, "                ));")?;
6979    writeln!(code, "            }}")?;
6980    writeln!(code, "            generated_count += 1;")?;
6981    writeln!(code, "            let logits = model.forward(next_token);")?;
6982    writeln!(code, "            next_token = argmax(&logits);")?;
6983    writeln!(code, "        }}")?;
6984    writeln!(
6985        code,
6986        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6987    )?;
6988    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6989    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6990    writeln!(code)?;
6991    writeln!(
6992        code,
6993        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6994    )?;
6995    writeln!(code, "        sse_body.push_str(&format!(")?;
6996    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6997    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6998    writeln!(code, "        ));")?;
6999    writeln!(code)?;
7000    writeln!(
7001        code,
7002        "        let resp = tiny_http::Response::from_string(sse_body)"
7003    )?;
7004    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
7005    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
7006    writeln!(code, "        request.respond(resp).ok();")?;
7007
7008    writeln!(code, "    }} else {{")?;
7009
7010    // -- Non-streaming response --
7011    writeln!(
7012        code,
7013        "        // Non-streaming: generate all tokens, return JSON"
7014    )?;
7015    writeln!(code, "        let gen_start = Instant::now();")?;
7016    writeln!(code, "        let mut generated_count: usize = 0;")?;
7017    writeln!(code, "        let mut generated = String::new();")?;
7018    writeln!(code, "        for _ in 0..max_tokens {{")?;
7019    writeln!(code, "            if next_token == 2 {{ break; }}")?;
7020    writeln!(
7021        code,
7022        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7023    )?;
7024    writeln!(code, "                generated.push_str(&text);")?;
7025    writeln!(code, "            }}")?;
7026    writeln!(code, "            generated_count += 1;")?;
7027    writeln!(code, "            let logits = model.forward(next_token);")?;
7028    writeln!(code, "            next_token = argmax(&logits);")?;
7029    writeln!(code, "        }}")?;
7030    writeln!(
7031        code,
7032        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7033    )?;
7034    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7035    writeln!(code)?;
7036    writeln!(code, "        let resp_json = serde_json::json!({{")?;
7037    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
7038    writeln!(code, "            \"object\": \"chat.completion\",")?;
7039    writeln!(code, "            \"choices\": [{{")?;
7040    writeln!(code, "                \"index\": 0,")?;
7041    writeln!(code, "                \"message\": {{")?;
7042    writeln!(code, "                    \"role\": \"assistant\",")?;
7043    writeln!(code, "                    \"content\": generated")?;
7044    writeln!(code, "                }},")?;
7045    writeln!(code, "                \"finish_reason\": \"stop\"")?;
7046    writeln!(code, "            }}],")?;
7047    writeln!(code, "            \"usage\": {{")?;
7048    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
7049    writeln!(
7050        code,
7051        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7052    )?;
7053    writeln!(
7054        code,
7055        "                \"generation_tokens\": generated_count,"
7056    )?;
7057    writeln!(
7058        code,
7059        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7060    )?;
7061    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
7062    writeln!(code, "            }}")?;
7063    writeln!(code, "        }});")?;
7064    writeln!(
7065        code,
7066        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
7067    )?;
7068    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7069    writeln!(code, "        request.respond(resp).ok();")?;
7070    writeln!(code, "    }}")?;
7071    writeln!(code, "}}")?;
7072
7073    Ok(code)
7074}
7075
7076// ---------------------------------------------------------------------------
7077// Tests
7078// ---------------------------------------------------------------------------
7079
7080#[cfg(test)]
7081mod tests {
7082    use super::*;
7083    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7084
7085    fn minimal_config() -> ModelConfig {
7086        ModelConfig {
7087            architecture: Architecture::Llama,
7088            hidden_size: 64,
7089            intermediate_size: 128,
7090            num_layers: 2,
7091            num_attention_heads: 4,
7092            num_kv_heads: 4,
7093            head_dim: 16,
7094            vocab_size: 256,
7095            max_seq_len: 512,
7096            rms_norm_eps: 1e-5,
7097            rope_theta: 10000.0,
7098            dtype: DType::F32,
7099            sliding_window_size: None,
7100            qkv_bias: false,
7101        }
7102    }
7103
7104    fn minimal_graph() -> Graph {
7105        Graph::new("test-metal").with_config(minimal_config())
7106    }
7107
7108    #[test]
7109    fn generate_metal_project_creates_files() {
7110        let dir = tempfile::tempdir().unwrap();
7111        let graph = minimal_graph();
7112        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7113
7114        assert!(
7115            dir.path().join("Cargo.toml").exists(),
7116            "Cargo.toml should be created"
7117        );
7118        assert!(
7119            dir.path().join("src/model.rs").exists(),
7120            "src/model.rs should be created"
7121        );
7122        assert!(
7123            dir.path().join("src/main.rs").exists(),
7124            "src/main.rs should be created"
7125        );
7126        assert!(
7127            dir.path().join("shaders/kernels.metal").exists(),
7128            "shaders/kernels.metal should be created"
7129        );
7130    }
7131
7132    #[test]
7133    fn generated_cargo_toml_has_metal_dep() {
7134        let toml = generate_cargo_toml("my-model");
7135        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7136        assert!(
7137            toml.contains("tokenizers"),
7138            "Cargo.toml should depend on tokenizers"
7139        );
7140        assert!(
7141            toml.contains("memmap2"),
7142            "Cargo.toml should depend on memmap2"
7143        );
7144        assert!(toml.contains("half"), "Cargo.toml should depend on half");
7145    }
7146
7147    #[test]
7148    fn generated_model_rs_contains_metal_code() {
7149        let config = minimal_config();
7150        let model_rs = generate_model_rs(&config).unwrap();
7151
7152        assert!(
7153            model_rs.contains("pub struct MetalModel"),
7154            "model.rs should define MetalModel struct"
7155        );
7156        assert!(
7157            model_rs.contains("matmul_pipeline: ComputePipelineState"),
7158            "MetalModel should have matmul_pipeline field"
7159        );
7160        assert!(
7161            model_rs.contains("Device::system_default()"),
7162            "model.rs should use Metal device"
7163        );
7164        assert!(
7165            model_rs.contains("new_library_with_source"),
7166            "model.rs should compile Metal shaders"
7167        );
7168        assert!(
7169            model_rs.contains("fn new(weights: &[u8])"),
7170            "MetalModel should implement new()"
7171        );
7172        assert!(
7173            model_rs.contains("fn forward(&mut self, token_id: u32)"),
7174            "MetalModel should implement forward()"
7175        );
7176    }
7177
7178    #[test]
7179    fn generated_shaders_contain_kernel_names() {
7180        let shaders = generate_metal_shaders(&minimal_config());
7181
7182        assert!(
7183            shaders.contains("kernel void matmul_vec"),
7184            "shaders should contain matmul_vec kernel"
7185        );
7186        assert!(
7187            shaders.contains("kernel void rms_norm"),
7188            "shaders should contain rms_norm kernel"
7189        );
7190        assert!(
7191            shaders.contains("kernel void rope"),
7192            "shaders should contain rope kernel"
7193        );
7194        assert!(
7195            shaders.contains("kernel void softmax"),
7196            "shaders should contain softmax kernel"
7197        );
7198        assert!(
7199            shaders.contains("kernel void silu_mul("),
7200            "shaders should contain silu_mul kernel"
7201        );
7202        assert!(
7203            shaders.contains("kernel void silu_mul_fused"),
7204            "shaders should contain silu_mul_fused kernel"
7205        );
7206        assert!(
7207            shaders.contains("kernel void elementwise_add"),
7208            "shaders should contain elementwise_add kernel"
7209        );
7210        assert!(
7211            shaders.contains("kernel void attention"),
7212            "shaders should contain attention kernel"
7213        );
7214        assert!(
7215            shaders.contains("kernel void add_inplace"),
7216            "shaders should contain add_inplace kernel"
7217        );
7218        assert!(
7219            shaders.contains("kernel void copy_buffer"),
7220            "shaders should contain copy_buffer kernel"
7221        );
7222        assert!(
7223            shaders.contains("kernel void copy_offset"),
7224            "shaders should contain copy_offset kernel"
7225        );
7226    }
7227
7228    #[test]
7229    fn generated_shaders_use_simdgroup_features() {
7230        let shaders = generate_metal_shaders(&minimal_config());
7231
7232        assert!(
7233            shaders.contains("threadgroup_barrier"),
7234            "shaders should use threadgroup barriers"
7235        );
7236        assert!(
7237            shaders.contains("threadgroup float"),
7238            "shaders should use threadgroup shared memory"
7239        );
7240        assert!(
7241            shaders.contains("thread_index_in_threadgroup"),
7242            "shaders should use threadgroup indexing"
7243        );
7244        assert!(
7245            shaders.contains("simd_sum"),
7246            "shaders should use simd_sum for warp-level reduction"
7247        );
7248        assert!(
7249            shaders.contains("simd_max"),
7250            "attention kernel should use simd_max for cooperative softmax"
7251        );
7252        assert!(
7253            shaders.contains("thread_index_in_simdgroup"),
7254            "shaders should use simdgroup lane indexing"
7255        );
7256        assert!(
7257            shaders.contains("simdgroup_index_in_threadgroup"),
7258            "shaders should use simdgroup indexing within threadgroup"
7259        );
7260        assert!(
7261            shaders.contains("float4"),
7262            "matmul_vec should use float4 vectorized loads"
7263        );
7264    }
7265
7266    #[test]
7267    fn generated_main_rs_has_tokenizer_usage() {
7268        let config = minimal_config();
7269        let main_rs = generate_main_rs("test-model", &config).unwrap();
7270
7271        assert!(
7272            main_rs.contains("tokenizers::Tokenizer"),
7273            "main.rs should use tokenizers crate"
7274        );
7275        assert!(
7276            main_rs.contains("MetalModel::new"),
7277            "main.rs should call MetalModel::new"
7278        );
7279        assert!(
7280            main_rs.contains("model.forward"),
7281            "main.rs should call model.forward"
7282        );
7283        assert!(
7284            main_rs.contains("memmap2"),
7285            "main.rs should use memmap2 for zero-copy weight loading"
7286        );
7287    }
7288
7289    #[test]
7290    fn missing_config_returns_error() {
7291        let dir = tempfile::tempdir().unwrap();
7292        let graph = Graph::new("no-config");
7293        let result = generate_metal_project(&graph, dir.path(), "fail");
7294        assert!(
7295            matches!(result, Err(MetalCodegenError::MissingConfig)),
7296            "should fail with MissingConfig when graph has no config"
7297        );
7298    }
7299
7300    #[test]
7301    fn sanitize_name_works() {
7302        assert_eq!(sanitize_name("My Model!"), "my-model");
7303        assert_eq!(sanitize_name("test_model"), "test-model");
7304        assert_eq!(sanitize_name("simple"), "simple");
7305    }
7306
7307    #[test]
7308    fn generated_forward_uses_single_command_buffer() {
7309        let config = minimal_config();
7310        let model_rs = generate_model_rs(&config).unwrap();
7311
7312        // The forward function should create exactly one command buffer.
7313        // Use the exact signature to avoid matching forward_prefill/forward_profile.
7314        let forward_start = model_rs
7315            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7316            .unwrap();
7317        let forward_body = &model_rs[forward_start..];
7318        // End at the next pub/private method
7319        let forward_end = forward_body
7320            .find("\n    pub fn forward_profile")
7321            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7322            .or_else(|| forward_body.find("\n    fn dispatch_"))
7323            .unwrap_or(forward_body.len());
7324        let forward_code = &forward_body[..forward_end];
7325
7326        // Should have exactly one new_command_buffer call
7327        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7328        assert_eq!(
7329            cmd_buf_count, 1,
7330            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7331        );
7332
7333        // Should have exactly one commit call
7334        let commit_count = forward_code.matches("cmd.commit()").count();
7335        assert_eq!(
7336            commit_count, 1,
7337            "forward() should commit exactly once, found {commit_count}"
7338        );
7339
7340        // Should wait: once for cmd + possibly once for prev_cmd drain
7341        let wait_count = forward_code.matches("wait_until_completed()").count();
7342        assert!(
7343            wait_count >= 1 && wait_count <= 2,
7344            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7345        );
7346    }
7347
7348    #[test]
7349    fn generated_model_has_preallocated_working_buffers() {
7350        let config = minimal_config();
7351        let model_rs = generate_model_rs(&config).unwrap();
7352
7353        for buf_name in &[
7354            "normed_buf",
7355            "qkv_buf",
7356            "attn_out_buf",
7357            "attn_proj_buf",
7358            "gate_up_buf",
7359            "ffn_hidden_buf",
7360            "ffn_out_buf",
7361            "add_tmp_buf",
7362        ] {
7363            assert!(
7364                model_rs.contains(&format!("{buf_name}: Buffer")),
7365                "MetalModel should have pre-allocated {buf_name} field"
7366            );
7367        }
7368    }
7369
7370    #[test]
7371    fn generated_dispatch_helpers_take_compute_encoder_ref() {
7372        let config = minimal_config();
7373        let model_rs = generate_model_rs(&config).unwrap();
7374
7375        for method in &[
7376            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7377            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7378            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7379            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7380            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7381            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7382            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7383            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7384            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7385            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7386            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7387            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7388            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7389        ] {
7390            assert!(
7391                model_rs.contains(method),
7392                "model.rs should contain dispatch helper: {method}"
7393            );
7394        }
7395    }
7396
7397    #[test]
7398    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7399        let config = minimal_config();
7400        let model_rs = generate_model_rs(&config).unwrap();
7401
7402        // Find dispatch helpers section and check none create their own encoders
7403        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7404        let helpers_code = &model_rs[helpers_start..];
7405
7406        // None of the dispatch_ helpers should call new_command_buffer
7407        assert!(
7408            !helpers_code.contains("self.queue.new_command_buffer()"),
7409            "dispatch helpers should not create their own command buffers"
7410        );
7411
7412        // None should create their own compute encoders
7413        assert!(
7414            !helpers_code.contains("new_compute_command_encoder()"),
7415            "dispatch helpers should not create their own compute encoders"
7416        );
7417
7418        // None should call end_encoding
7419        assert!(
7420            !helpers_code.contains("end_encoding()"),
7421            "dispatch helpers should not call end_encoding"
7422        );
7423
7424        // None should call commit or wait
7425        assert!(
7426            !helpers_code.contains(".commit()"),
7427            "dispatch helpers should not commit command buffers"
7428        );
7429        assert!(
7430            !helpers_code.contains("wait_until_completed"),
7431            "dispatch helpers should not wait on command buffers"
7432        );
7433    }
7434
7435    #[test]
7436    fn generated_forward_batches_compute_encoders() {
7437        let config = minimal_config();
7438        let model_rs = generate_model_rs(&config).unwrap();
7439
7440        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7441        let forward_start = model_rs
7442            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7443            .unwrap();
7444        let forward_body = &model_rs[forward_start..];
7445        let forward_end = forward_body
7446            .find("\n    pub fn forward_profile")
7447            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7448            .or_else(|| forward_body.find("\n    fn dispatch_"))
7449            .unwrap_or(forward_body.len());
7450        let forward_code = &forward_body[..forward_end];
7451
7452        // Forward should not allocate new buffers
7453        assert!(
7454            !forward_code.contains("device.new_buffer"),
7455            "forward() should not allocate new buffers per call"
7456        );
7457
7458        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7459        // Copy operations use compute copy kernels instead of blit encoders.
7460        let compute_encoder_count = forward_code
7461            .matches("new_compute_command_encoder()")
7462            .count();
7463        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7464
7465        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7466        assert_eq!(
7467            compute_encoder_count, 1,
7468            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7469        );
7470        assert_eq!(
7471            blit_encoder_count, 0,
7472            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7473        );
7474    }
7475
7476    #[test]
7477    fn generated_forward_uses_add_inplace() {
7478        let config = minimal_config();
7479        let model_rs = generate_model_rs(&config).unwrap();
7480
7481        // Should use in-place add (no blit copy-back needed)
7482        assert!(
7483            model_rs.contains("dispatch_add_inplace"),
7484            "forward() should use dispatch_add_inplace for residual connections"
7485        );
7486        assert!(
7487            model_rs.contains("add_inplace_pipeline"),
7488            "MetalModel should have add_inplace_pipeline"
7489        );
7490    }
7491
7492    fn minimal_q8_config() -> ModelConfig {
7493        ModelConfig {
7494            architecture: Architecture::Llama,
7495            hidden_size: 64,
7496            intermediate_size: 128,
7497            num_layers: 2,
7498            num_attention_heads: 4,
7499            num_kv_heads: 4,
7500            head_dim: 16,
7501            vocab_size: 256,
7502            max_seq_len: 512,
7503            rms_norm_eps: 1e-5,
7504            rope_theta: 10000.0,
7505            dtype: DType::Q8_0,
7506            sliding_window_size: None,
7507            qkv_bias: false,
7508        }
7509    }
7510
7511    #[test]
7512    fn generated_shaders_contain_q8_kernel() {
7513        let shaders = generate_metal_shaders(&minimal_config());
7514
7515        assert!(
7516            shaders.contains("kernel void matmul_vec_q8"),
7517            "shaders should contain matmul_vec_q8 kernel"
7518        );
7519        assert!(
7520            shaders.contains("device const uchar* matrix"),
7521            "matmul_vec_q8 should accept raw Q8_0 bytes"
7522        );
7523        assert!(
7524            shaders.contains("packed_short4"),
7525            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7526        );
7527        assert!(
7528            shaders.contains("as_type<char2>"),
7529            "matmul_vec_q8 should bitcast short lanes to char2"
7530        );
7531        assert!(
7532            shaders.contains("device const half*"),
7533            "matmul_vec_q8 should read f16 scale via half pointer"
7534        );
7535    }
7536
7537    #[test]
7538    fn generated_model_uses_fused_qkv_projections() {
7539        let config = minimal_config();
7540        let model_rs = generate_model_rs(&config).unwrap();
7541
7542        // Should have fused QKV weight in layer buffers
7543        assert!(
7544            model_rs.contains("qkv_weight: Buffer"),
7545            "LayerBuffers should have fused qkv_weight field"
7546        );
7547        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7548        assert!(
7549            !model_rs.contains("    q_weight: Buffer"),
7550            "LayerBuffers should not have separate q_weight field"
7551        );
7552        assert!(
7553            !model_rs.contains("    k_weight: Buffer"),
7554            "LayerBuffers should not have separate k_weight field"
7555        );
7556        assert!(
7557            !model_rs.contains("    v_weight: Buffer"),
7558            "LayerBuffers should not have separate v_weight field"
7559        );
7560
7561        // Should have fused gate_up_weight
7562        assert!(
7563            model_rs.contains("gate_up_weight: Buffer"),
7564            "LayerBuffers should have fused gate_up_weight field"
7565        );
7566        // Should NOT have separate gate/up weight fields
7567        assert!(
7568            !model_rs.contains("    gate_weight: Buffer"),
7569            "LayerBuffers should not have separate gate_weight field"
7570        );
7571        assert!(
7572            !model_rs.contains("    up_weight: Buffer"),
7573            "LayerBuffers should not have separate up_weight field"
7574        );
7575
7576        // Should have fused working buffers
7577        assert!(
7578            model_rs.contains("qkv_buf: Buffer"),
7579            "MetalModel should have fused qkv_buf"
7580        );
7581        assert!(
7582            model_rs.contains("gate_up_buf: Buffer"),
7583            "MetalModel should have fused gate_up_buf"
7584        );
7585
7586        // Forward pass should use fused dispatch
7587        assert!(
7588            model_rs.contains("dispatch_silu_mul_fused"),
7589            "forward pass should use dispatch_silu_mul_fused"
7590        );
7591        assert!(
7592            model_rs.contains("dispatch_rope_offset"),
7593            "forward pass should use dispatch_rope_offset for fused QKV"
7594        );
7595        assert!(
7596            model_rs.contains("dispatch_attention_offset"),
7597            "forward pass should use dispatch_attention_offset for fused QKV"
7598        );
7599    }
7600
7601    #[test]
7602    fn q8_model_has_matmul_q8_pipeline() {
7603        let config = minimal_q8_config();
7604        let model_rs = generate_model_rs(&config).unwrap();
7605
7606        assert!(
7607            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7608            "MetalModel should have matmul_q8_pipeline field"
7609        );
7610        assert!(
7611            model_rs.contains("matmul_q8_pipeline,"),
7612            "MetalModel Self should include matmul_q8_pipeline"
7613        );
7614    }
7615
7616    #[test]
7617    fn q8_model_uses_dispatch_matmul_q8() {
7618        let config = minimal_q8_config();
7619        let model_rs = generate_model_rs(&config).unwrap();
7620
7621        assert!(
7622            model_rs.contains("dispatch_matmul_q8"),
7623            "Q8_0 model should use dispatch_matmul_q8 for projections"
7624        );
7625        assert!(
7626            model_rs.contains("fn dispatch_matmul_q8"),
7627            "model.rs should define dispatch_matmul_q8 method"
7628        );
7629    }
7630
7631    #[test]
7632    fn q8_model_loads_raw_bytes_not_dequantized() {
7633        let config = minimal_q8_config();
7634        let model_rs = generate_model_rs(&config).unwrap();
7635
7636        // Should NOT contain dequantization code
7637        assert!(
7638            !model_rs.contains("f16_to_f32"),
7639            "Q8_0 model should not dequantize weights to f32"
7640        );
7641        assert!(
7642            !model_rs.contains("f32_data"),
7643            "Q8_0 model should not create f32 weight data"
7644        );
7645
7646        // Should load raw Q8_0 bytes directly
7647        assert!(
7648            model_rs.contains("total_raw as u64"),
7649            "Q8_0 model should load raw bytes into Metal buffer"
7650        );
7651    }
7652
7653    #[test]
7654    fn q8_model_norms_stay_f32() {
7655        let config = minimal_q8_config();
7656        let model_rs = generate_model_rs(&config).unwrap();
7657
7658        // Norm weights should still use f32 buffers
7659        assert!(
7660            model_rs.contains("let attn_norm = next_f32_buffer"),
7661            "attn_norm should use f32 buffer even for Q8_0 models"
7662        );
7663        assert!(
7664            model_rs.contains("let ffn_norm = next_f32_buffer"),
7665            "ffn_norm should use f32 buffer even for Q8_0 models"
7666        );
7667        assert!(
7668            model_rs.contains("let norm_buf = next_f32_buffer"),
7669            "final norm should use f32 buffer even for Q8_0 models"
7670        );
7671    }
7672
7673    #[test]
7674    fn q8_model_uses_fused_weight_loading() {
7675        let config = minimal_q8_config();
7676        let model_rs = generate_model_rs(&config).unwrap();
7677
7678        // Should use fused Q8 buffer loading for QKV
7679        assert!(
7680            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7681            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7682        );
7683        // Should use fused Q8 buffer loading for gate+up
7684        assert!(
7685            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7686            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7687        );
7688        // Should still use regular q8 buffer for individual weights
7689        assert!(
7690            model_rs.contains("let o_weight = next_q8_buffer"),
7691            "Q8_0 model should use next_q8_buffer for O weight"
7692        );
7693        assert!(
7694            model_rs.contains("let down_weight = next_q8_buffer"),
7695            "Q8_0 model should use next_q8_buffer for down weight"
7696        );
7697    }
7698
7699    #[test]
7700    fn f32_model_does_not_use_q8_dispatch() {
7701        let config = minimal_config();
7702        let model_rs = generate_model_rs(&config).unwrap();
7703
7704        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7705        let forward_start = model_rs
7706            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7707            .unwrap();
7708        let forward_body = &model_rs[forward_start..];
7709        let forward_end = forward_body
7710            .find("\n    fn dispatch_")
7711            .unwrap_or(forward_body.len());
7712        let forward_code = &forward_body[..forward_end];
7713
7714        assert!(
7715            !forward_code.contains("dispatch_matmul_q8"),
7716            "f32 model forward should not use dispatch_matmul_q8"
7717        );
7718    }
7719
7720    #[test]
7721    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7722        let config = minimal_q8_config();
7723        let model_rs = generate_model_rs(&config).unwrap();
7724
7725        assert!(
7726            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7727            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7728        );
7729    }
7730
7731    #[test]
7732    fn generated_model_has_double_buffered_prefill() {
7733        let config = minimal_config();
7734        let model_rs = generate_model_rs(&config).unwrap();
7735
7736        // MetalModel should have prev_cmd field for double-buffered prefill
7737        assert!(
7738            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7739            "MetalModel should have prev_cmd field for double-buffered prefill"
7740        );
7741
7742        // Should have forward_prefill method
7743        assert!(
7744            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7745            "MetalModel should have forward_prefill method"
7746        );
7747
7748        // forward() should drain prev_cmd at the start
7749        assert!(
7750            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7751            "forward() should drain prev_cmd from previous prefill"
7752        );
7753    }
7754
7755    #[test]
7756    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7757        let config = minimal_config();
7758        let main_rs = generate_main_rs("test-model", &config).unwrap();
7759
7760        assert!(
7761            main_rs.contains("forward_prefill"),
7762            "main.rs should use forward_prefill for intermediate prompt tokens"
7763        );
7764        assert!(
7765            main_rs.contains("double-buffered"),
7766            "main.rs should document double-buffered prefill"
7767        );
7768    }
7769
7770    #[test]
7771    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7772        let shaders = generate_metal_shaders(&minimal_config());
7773
7774        assert!(
7775            shaders.contains("packed_short4"),
7776            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7777        );
7778        assert!(
7779            shaders.contains("d0[0]"),
7780            "matmul_vec_q8 should index the wide pointer for row 0"
7781        );
7782        assert!(
7783            shaders.contains("as_type<char2>"),
7784            "matmul_vec_q8 should bitcast short lanes to char2"
7785        );
7786        assert!(
7787            shaders.contains("dot("),
7788            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7789        );
7790    }
7791
7792    // ── Q4_0 tests ──────────────────────────────────────────────────────
7793
7794    fn minimal_q4_config() -> ModelConfig {
7795        ModelConfig {
7796            architecture: Architecture::Llama,
7797            hidden_size: 64,
7798            intermediate_size: 128,
7799            num_layers: 2,
7800            num_attention_heads: 4,
7801            num_kv_heads: 4,
7802            head_dim: 16,
7803            vocab_size: 256,
7804            max_seq_len: 512,
7805            rms_norm_eps: 1e-5,
7806            rope_theta: 10000.0,
7807            dtype: DType::Q4_0,
7808            sliding_window_size: None,
7809            qkv_bias: false,
7810        }
7811    }
7812
7813    #[test]
7814    fn generated_shaders_contain_q4_kernel() {
7815        let shaders = generate_metal_shaders(&minimal_config());
7816
7817        assert!(
7818            shaders.contains("kernel void matmul_vec_q4"),
7819            "shaders should contain matmul_vec_q4 kernel"
7820        );
7821        assert!(
7822            shaders.contains("Q4_ROWS_PER_TG"),
7823            "shaders should define Q4_ROWS_PER_TG constant"
7824        );
7825        assert!(
7826            shaders.contains("Q4_ROWS_PER_SG"),
7827            "shaders should define Q4_ROWS_PER_SG constant"
7828        );
7829    }
7830
7831    #[test]
7832    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7833        let shaders = generate_metal_shaders(&minimal_config());
7834
7835        // Q4_0 kernel should use uchar4 for packed byte loads
7836        assert!(
7837            shaders.contains("uchar4"),
7838            "matmul_vec_q4 should use uchar4 for packed byte loads"
7839        );
7840        // Should unpack nibbles with &0xF and >>4
7841        assert!(
7842            shaders.contains("&0xF"),
7843            "matmul_vec_q4 should extract low nibble with &0xF"
7844        );
7845        assert!(
7846            shaders.contains(">>4"),
7847            "matmul_vec_q4 should extract high nibble with >>4"
7848        );
7849        // Should subtract 8 to convert unsigned to signed
7850        assert!(
7851            shaders.contains("-8)"),
7852            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7853        );
7854        // Should use 18-byte block size
7855        assert!(
7856            shaders.contains("blk * 18"),
7857            "matmul_vec_q4 should use 18-byte block stride"
7858        );
7859    }
7860
7861    #[test]
7862    fn q4_model_has_matmul_q4_pipeline() {
7863        let config = minimal_q4_config();
7864        let model_rs = generate_model_rs(&config).unwrap();
7865
7866        assert!(
7867            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7868            "MetalModel should have matmul_q4_pipeline field"
7869        );
7870        assert!(
7871            model_rs.contains("matmul_q4_pipeline,"),
7872            "MetalModel Self should include matmul_q4_pipeline"
7873        );
7874    }
7875
7876    #[test]
7877    fn q4_model_uses_dispatch_matmul_q4() {
7878        let config = minimal_q4_config();
7879        let model_rs = generate_model_rs(&config).unwrap();
7880
7881        assert!(
7882            model_rs.contains("dispatch_matmul_q4"),
7883            "Q4_0 model should use dispatch_matmul_q4 for projections"
7884        );
7885        assert!(
7886            model_rs.contains("fn dispatch_matmul_q4"),
7887            "model.rs should define dispatch_matmul_q4 method"
7888        );
7889    }
7890
7891    #[test]
7892    fn q4_model_loads_raw_bytes_not_dequantized() {
7893        let config = minimal_q4_config();
7894        let model_rs = generate_model_rs(&config).unwrap();
7895
7896        // Should NOT contain dequantization code
7897        assert!(
7898            !model_rs.contains("f16_to_f32"),
7899            "Q4_0 model should not dequantize weights to f32"
7900        );
7901
7902        // Should load raw Q4_0 bytes directly
7903        assert!(
7904            model_rs.contains("total_raw as u64"),
7905            "Q4_0 model should load raw bytes into Metal buffer"
7906        );
7907    }
7908
7909    #[test]
7910    fn q4_model_norms_stay_f32() {
7911        let config = minimal_q4_config();
7912        let model_rs = generate_model_rs(&config).unwrap();
7913
7914        assert!(
7915            model_rs.contains("let attn_norm = next_f32_buffer"),
7916            "attn_norm should use f32 buffer even for Q4_0 models"
7917        );
7918        assert!(
7919            model_rs.contains("let ffn_norm = next_f32_buffer"),
7920            "ffn_norm should use f32 buffer even for Q4_0 models"
7921        );
7922        assert!(
7923            model_rs.contains("let norm_buf = next_f32_buffer"),
7924            "final norm should use f32 buffer even for Q4_0 models"
7925        );
7926    }
7927
7928    #[test]
7929    fn q4_model_uses_fused_weight_loading() {
7930        let config = minimal_q4_config();
7931        let model_rs = generate_model_rs(&config).unwrap();
7932
7933        assert!(
7934            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7935            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7936        );
7937        assert!(
7938            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7939            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7940        );
7941        assert!(
7942            model_rs.contains("let o_weight = next_q4_buffer"),
7943            "Q4_0 model should use next_q4_buffer for O weight"
7944        );
7945        assert!(
7946            model_rs.contains("let down_weight = next_q4_buffer"),
7947            "Q4_0 model should use next_q4_buffer for down weight"
7948        );
7949    }
7950
7951    #[test]
7952    fn attention_flash_batch_kernel_exists() {
7953        // The flash kernel is still wired into the library (pipeline, kernel
7954        // source).  Dispatch is currently routed to the legacy path pending a
7955        // fix for a numerical issue discovered after the prompt-chunking fix.
7956        let config = minimal_config();
7957        let model_rs = generate_model_rs(&config).unwrap();
7958        let shaders = generate_metal_shaders(&config);
7959
7960        assert!(
7961            shaders.contains("kernel void attention_flash_batch"),
7962            "shaders.metal must still contain the attention_flash_batch kernel"
7963        );
7964        assert!(
7965            shaders.contains("FLASH_K_TILE"),
7966            "flash kernel must tile K/V with a FLASH_K_TILE constant"
7967        );
7968        assert!(
7969            model_rs.contains("attention_flash_batch_pipeline"),
7970            "MetalModel must register the flash attention pipeline"
7971        );
7972    }
7973
7974    #[test]
7975    fn decode_uses_fused_rope_and_kv_copy() {
7976        // v0.7.2: single-token decode reuses `rope_qk_batch` (M=1) and
7977        // `copy_kv_both_batch` (M=1) instead of calling the separate
7978        // rope_offset + copy_from_offset_f16 pairs, saving 2 dispatches +
7979        // barriers per layer.
7980        let config = minimal_config();
7981        let model_rs = generate_model_rs(&config).unwrap();
7982
7983        // forward() and forward_profile() each have a decode loop.
7984        // Locate `pub fn forward(` and `pub fn forward_profile(` and verify
7985        // the fused dispatches appear between the function header and the
7986        // start of forward_prefill (which also uses these fused helpers).
7987        let forward_start = model_rs.find("pub fn forward(").expect("forward missing");
7988        let forward_end = model_rs[forward_start..]
7989            .find("pub fn forward_prefill(")
7990            .expect("forward_prefill missing");
7991        let forward_body = &model_rs[forward_start..forward_start + forward_end];
7992
7993        assert!(
7994            forward_body.contains("dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1"),
7995            "single-token forward must use fused rope_qk_batch with M=1"
7996        );
7997        assert!(
7998            forward_body.contains("dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,"),
7999            "single-token forward must use fused copy_kv_both_batch"
8000        );
8001        assert!(
8002            !forward_body.contains("dispatch_copy_from_offset_f16"),
8003            "decode path should no longer call the per-K/per-V copy"
8004        );
8005    }
8006
8007    #[test]
8008    fn kv_cache_stored_as_f16() {
8009        // v0.7.1: KV cache is f16 (2 bytes/element) instead of f32 to halve
8010        // attention memory bandwidth and KV RAM footprint.  All attention
8011        // kernels must read K/V as `device const half*`; copy kernels must
8012        // write half; the f32->f16 copy kernel must be wired for single-
8013        // token decode KV writes.
8014        let config = minimal_config();
8015        let shaders = generate_metal_shaders(&config);
8016        let model_rs = generate_model_rs(&config).unwrap();
8017
8018        for kernel in [
8019            "kernel void attention(",
8020            "kernel void attention_batch(",
8021            "kernel void attention_flash_batch(",
8022            "kernel void attention_mma_flash_batch(",
8023        ] {
8024            let start = shaders
8025                .find(kernel)
8026                .unwrap_or_else(|| panic!("kernel {kernel} missing"));
8027            // Slice through the signature block — scan for "){" which closes
8028            // the parameter list at the top of every kernel's body.
8029            let sig_end = shaders[start..]
8030                .find("){")
8031                .unwrap_or_else(|| shaders[start..].find(") {").unwrap());
8032            let sig = &shaders[start..start + sig_end];
8033            assert!(
8034                sig.contains("device const half*")
8035                    && sig.contains("k_cache")
8036                    && sig.contains("v_cache"),
8037                "{kernel} must read k_cache/v_cache as `device const half*`"
8038            );
8039            assert!(
8040                !sig.contains("device const float* k_cache")
8041                    && !sig.contains("device const float*  k_cache"),
8042                "{kernel} still reads k_cache as float"
8043            );
8044        }
8045
8046        assert!(
8047            shaders.contains("kernel void copy_f32_to_f16_offset"),
8048            "f32->f16 copy kernel must be present for single-token decode KV writes"
8049        );
8050        assert!(
8051            model_rs.contains("dispatch_copy_from_offset_f16"),
8052            "single-token decode must dispatch the f32->f16 KV copy"
8053        );
8054    }
8055
8056    #[test]
8057    fn decode_attention_uses_half4_vectorized_loads() {
8058        // v0.7.3: the decode `kernel void attention` vectorizes both the
8059        // Q·K^T dot product and the scores·V weighted sum via half4 loads
8060        // (and float4 stores on output), mirroring the prefill path.
8061        let config = minimal_config();
8062        let shaders = generate_metal_shaders(&config);
8063
8064        let start = shaders
8065            .find("kernel void attention(")
8066            .expect("decode attention kernel missing");
8067        // Slice just the decode kernel body — stop at the next kernel.
8068        let end_rel = shaders[start + 1..]
8069            .find("kernel void ")
8070            .expect("next kernel missing");
8071        let body = &shaders[start..start + 1 + end_rel];
8072
8073        assert!(
8074            body.contains("device const half4*"),
8075            "decode attention must half4-load K/V"
8076        );
8077        assert!(
8078            body.contains("device const float4*"),
8079            "decode attention must float4-load Q"
8080        );
8081        assert!(
8082            body.contains("device float4*"),
8083            "decode attention must float4-store output"
8084        );
8085        assert!(
8086            body.contains("head_dim4"),
8087            "decode attention must iterate head_dim in chunks of 4"
8088        );
8089    }
8090
8091    #[test]
8092    fn attention_mma_flash_batch_kernel_wired() {
8093        // MMA-accelerated flash attention (issue #212).  Default-on in v0.7.0
8094        // when HEAD_DIM ≤ 128 and num_tokens ≥ 8.  FORGE_MMA_ATTN=0 opts out.
8095        let config = minimal_config();
8096        let model_rs = generate_model_rs(&config).unwrap();
8097        let shaders = generate_metal_shaders(&config);
8098
8099        assert!(
8100            shaders.contains("kernel void attention_mma_flash_batch"),
8101            "shaders.metal must contain the MMA flash kernel"
8102        );
8103        assert!(
8104            shaders.contains("FLASH_MMA_Q_BLOCK"),
8105            "MMA flash kernel must define Q_BLOCK tiling constant"
8106        );
8107        assert!(
8108            shaders.contains("simdgroup_multiply_accumulate"),
8109            "MMA flash kernel must use hardware MMA"
8110        );
8111        assert!(
8112            model_rs.contains("attention_mma_flash_batch_pipeline"),
8113            "MetalModel must register the MMA flash pipeline"
8114        );
8115        assert!(
8116            model_rs.contains("mma_opt_out"),
8117            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out"
8118        );
8119        assert!(
8120            model_rs.contains("!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8"),
8121            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8"
8122        );
8123    }
8124
8125    #[test]
8126    fn forward_prefill_batch_chunks_by_max_batch_size() {
8127        // Regression: prior to v0.6.4 forward_prefill_batch truncated prompts
8128        // longer than MAX_BATCH_SIZE (512) tokens via `.min(MAX_BATCH_SIZE)`,
8129        // silently dropping the middle of long prompts.  Must now loop over
8130        // MAX_BATCH_SIZE-sized chunks and carry KV-cache state across them.
8131        let config = minimal_config();
8132        let model_rs = generate_model_rs(&config).unwrap();
8133        assert!(
8134            model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
8135            "forward_prefill_batch must chunk long prompts"
8136        );
8137        assert!(
8138            !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
8139            "the old truncation path must be gone"
8140        );
8141    }
8142
8143    #[test]
8144    fn qwen2_qkv_bias_wired_through_metal_codegen() {
8145        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
8146        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
8147        // loader, pipeline, and dispatch call in the expected places.
8148        let config = ModelConfig {
8149            architecture: Architecture::Qwen2,
8150            qkv_bias: true,
8151            ..minimal_config()
8152        };
8153        let model_rs = generate_model_rs(&config).unwrap();
8154
8155        assert!(
8156            model_rs.contains("qkv_bias: Buffer"),
8157            "Qwen2 LayerBuffers must declare qkv_bias field"
8158        );
8159        assert!(
8160            model_rs.contains("let qkv_bias = next_f32_buffer"),
8161            "Qwen2 layer init must load the bias from the weight blob"
8162        );
8163        assert!(
8164            model_rs.contains("add_bias_batch_pipeline"),
8165            "Qwen2 model struct must include the add_bias_batch_pipeline"
8166        );
8167        assert!(
8168            model_rs.contains("fn dispatch_add_bias_batch"),
8169            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
8170        );
8171        assert!(
8172            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
8173            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
8174        );
8175        assert!(
8176            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
8177            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
8178        );
8179
8180        // The add_bias_batch MSL kernel must be in the shader source.
8181        let shaders = generate_metal_shaders(&config);
8182        assert!(
8183            shaders.contains("kernel void add_bias_batch"),
8184            "shaders.metal must contain the add_bias_batch kernel"
8185        );
8186    }
8187
8188    #[test]
8189    fn llama_does_not_emit_qkv_bias_machinery() {
8190        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
8191        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
8192        let config = minimal_config();
8193        assert!(!config.qkv_bias);
8194        let model_rs = generate_model_rs(&config).unwrap();
8195        assert!(
8196            !model_rs.contains("qkv_bias: Buffer"),
8197            "Llama must not have qkv_bias field"
8198        );
8199        assert!(
8200            !model_rs.contains("add_bias_batch_pipeline"),
8201            "Llama must not pull in add_bias_batch_pipeline"
8202        );
8203        assert!(
8204            !model_rs.contains("dispatch_add_bias_batch"),
8205            "Llama must not call dispatch_add_bias_batch"
8206        );
8207    }
8208
8209    #[test]
8210    fn q4_dispatch_helper_takes_compute_encoder_ref() {
8211        let config = minimal_q4_config();
8212        let model_rs = generate_model_rs(&config).unwrap();
8213
8214        assert!(
8215            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
8216            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
8217        );
8218    }
8219
8220    #[test]
8221    fn f32_model_does_not_use_q4_dispatch() {
8222        let config = minimal_config();
8223        let model_rs = generate_model_rs(&config).unwrap();
8224
8225        let forward_start = model_rs
8226            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8227            .unwrap();
8228        let forward_body = &model_rs[forward_start..];
8229        let forward_end = forward_body
8230            .find("\n    fn dispatch_")
8231            .unwrap_or(forward_body.len());
8232        let forward_code = &forward_body[..forward_end];
8233
8234        assert!(
8235            !forward_code.contains("dispatch_matmul_q4"),
8236            "f32 model forward should not use dispatch_matmul_q4"
8237        );
8238    }
8239
8240    #[test]
8241    fn q4_model_lm_head_uses_q4_buffer() {
8242        let config = minimal_q4_config();
8243        let model_rs = generate_model_rs(&config).unwrap();
8244
8245        assert!(
8246            model_rs.contains("let lm_head_buf = next_q4_buffer"),
8247            "Q4_0 model should use next_q4_buffer for lm_head"
8248        );
8249    }
8250
8251    #[test]
8252    fn vec_tile_size_matches_model_dimensions() {
8253        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
8254        let small = minimal_config();
8255        let shaders_small = generate_metal_shaders(&small);
8256        assert!(
8257            shaders_small.contains("vec_tile[128]"),
8258            "vec_tile should be sized to max(hidden, intermediate) = 128"
8259        );
8260
8261        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
8262        let mut large = minimal_config();
8263        large.hidden_size = 2048;
8264        large.intermediate_size = 8192;
8265        let shaders_large = generate_metal_shaders(&large);
8266        assert!(
8267            shaders_large.contains("vec_tile[8192]"),
8268            "vec_tile should be 8192 for models with intermediate=8192"
8269        );
8270        assert!(
8271            !shaders_large.contains("vec_tile[4096]"),
8272            "vec_tile should NOT be hardcoded to 4096"
8273        );
8274    }
8275
8276    #[test]
8277    fn generated_cargo_toml_has_server_deps() {
8278        let toml = generate_cargo_toml("my-model");
8279        assert!(
8280            toml.contains("tiny_http"),
8281            "Cargo.toml should depend on tiny_http"
8282        );
8283        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
8284        assert!(
8285            toml.contains("serde_json"),
8286            "Cargo.toml should depend on serde_json"
8287        );
8288    }
8289
8290    #[test]
8291    fn generated_main_rs_has_serve_mode() {
8292        let config = minimal_config();
8293        let main_rs = generate_main_rs("test-model", &config).unwrap();
8294
8295        assert!(
8296            main_rs.contains("--serve"),
8297            "main.rs should parse --serve flag"
8298        );
8299        assert!(
8300            main_rs.contains("--port"),
8301            "main.rs should parse --port flag"
8302        );
8303        assert!(
8304            main_rs.contains("fn serve("),
8305            "main.rs should define serve function"
8306        );
8307        assert!(
8308            main_rs.contains("tiny_http::Server::http"),
8309            "main.rs should create tiny_http server"
8310        );
8311    }
8312
8313    #[test]
8314    fn generated_main_rs_has_chat_completions_endpoint() {
8315        let config = minimal_config();
8316        let main_rs = generate_main_rs("test-model", &config).unwrap();
8317
8318        assert!(
8319            main_rs.contains("/v1/chat/completions"),
8320            "main.rs should handle /v1/chat/completions endpoint"
8321        );
8322        assert!(
8323            main_rs.contains("/v1/models"),
8324            "main.rs should handle /v1/models endpoint"
8325        );
8326        assert!(
8327            main_rs.contains("/health"),
8328            "main.rs should handle /health endpoint"
8329        );
8330    }
8331
8332    #[test]
8333    fn generated_main_rs_has_sse_streaming() {
8334        let config = minimal_config();
8335        let main_rs = generate_main_rs("test-model", &config).unwrap();
8336
8337        assert!(
8338            main_rs.contains("text/event-stream"),
8339            "main.rs should set SSE content type for streaming"
8340        );
8341        assert!(
8342            main_rs.contains("chat.completion.chunk"),
8343            "main.rs should emit SSE chunks"
8344        );
8345        assert!(
8346            main_rs.contains("[DONE]"),
8347            "main.rs should emit [DONE] sentinel"
8348        );
8349    }
8350
8351    #[test]
8352    fn generated_main_rs_has_chat_message_formatting() {
8353        let config = minimal_config();
8354        let main_rs = generate_main_rs("test-model", &config).unwrap();
8355
8356        assert!(
8357            main_rs.contains("fn format_chat_messages"),
8358            "main.rs should define format_chat_messages function"
8359        );
8360        assert!(
8361            main_rs.contains("<|im_start|>"),
8362            "main.rs should use ChatML format"
8363        );
8364        assert!(
8365            main_rs.contains("<|im_end|>"),
8366            "main.rs should use ChatML format"
8367        );
8368    }
8369
8370    #[test]
8371    fn generated_main_rs_has_request_types() {
8372        let config = minimal_config();
8373        let main_rs = generate_main_rs("test-model", &config).unwrap();
8374
8375        assert!(
8376            main_rs.contains("struct ChatRequest"),
8377            "main.rs should define ChatRequest struct"
8378        );
8379        assert!(
8380            main_rs.contains("struct ChatMessage"),
8381            "main.rs should define ChatMessage struct"
8382        );
8383        assert!(
8384            main_rs.contains("Deserialize"),
8385            "main.rs should derive Deserialize for request types"
8386        );
8387    }
8388
8389    #[test]
8390    fn generated_model_has_reset_method() {
8391        let config = minimal_config();
8392        let model_rs = generate_model_rs(&config).unwrap();
8393
8394        assert!(
8395            model_rs.contains("pub fn reset(&mut self)"),
8396            "model.rs should have a reset() method for multi-request serving"
8397        );
8398        assert!(
8399            model_rs.contains("self.pos = 0"),
8400            "reset() should reset position to 0"
8401        );
8402    }
8403
8404    #[test]
8405    fn generated_main_rs_cli_mode_still_works() {
8406        let config = minimal_config();
8407        let main_rs = generate_main_rs("test-model", &config).unwrap();
8408
8409        // CLI mode should still be functional
8410        assert!(
8411            main_rs.contains("fn cli_mode("),
8412            "main.rs should define cli_mode function"
8413        );
8414        assert!(
8415            main_rs.contains("model.forward"),
8416            "main.rs should call model.forward"
8417        );
8418        assert!(
8419            main_rs.contains("model.forward_prefill"),
8420            "main.rs should call model.forward_prefill"
8421        );
8422    }
8423
8424    // ── Batched prefill tests ──────────────────────────────────────────
8425
8426    #[test]
8427    fn generated_shaders_contain_batch_kernels() {
8428        let shaders = generate_metal_shaders(&minimal_config());
8429
8430        assert!(
8431            shaders.contains("kernel void matmul_vec_batch"),
8432            "shaders should contain matmul_vec_batch kernel"
8433        );
8434        assert!(
8435            shaders.contains("kernel void matmul_vec_q8_batch"),
8436            "shaders should contain matmul_vec_q8_batch kernel"
8437        );
8438        assert!(
8439            shaders.contains("kernel void matmul_q8_gemm_batch"),
8440            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
8441        );
8442        assert!(
8443            shaders.contains("kernel void matmul_vec_q4_batch"),
8444            "shaders should contain matmul_vec_q4_batch kernel"
8445        );
8446        assert!(
8447            shaders.contains("kernel void rms_norm_batch"),
8448            "shaders should contain rms_norm_batch kernel"
8449        );
8450        assert!(
8451            shaders.contains("kernel void silu_mul_fused_batch"),
8452            "shaders should contain silu_mul_fused_batch kernel"
8453        );
8454        assert!(
8455            shaders.contains("kernel void add_inplace_batch"),
8456            "shaders should contain add_inplace_batch kernel"
8457        );
8458        assert!(
8459            shaders.contains("kernel void copy_embedding_batch"),
8460            "shaders should contain copy_embedding_batch kernel"
8461        );
8462    }
8463
8464    #[test]
8465    fn generated_model_has_batch_pipelines() {
8466        let config = minimal_config();
8467        let model_rs = generate_model_rs(&config).unwrap();
8468
8469        for pipeline in &[
8470            "matmul_batch_pipeline",
8471            "matmul_q8_batch_pipeline",
8472            "matmul_q8_gemm_batch_pipeline",
8473            "matmul_q4_batch_pipeline",
8474            "rms_norm_batch_pipeline",
8475            "rope_batch_pipeline",
8476            "silu_mul_fused_batch_pipeline",
8477            "add_inplace_batch_pipeline",
8478            "copy_embedding_batch_pipeline",
8479            "attention_batch_pipeline",
8480            "copy_kv_batch_pipeline",
8481        ] {
8482            assert!(
8483                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
8484                "MetalModel should have {pipeline} field"
8485            );
8486        }
8487    }
8488
8489    #[test]
8490    fn generated_model_has_batch_buffers() {
8491        let config = minimal_config();
8492        let model_rs = generate_model_rs(&config).unwrap();
8493
8494        for buf in &[
8495            "batch_hidden_buf",
8496            "batch_residual_buf",
8497            "batch_qkv_buf",
8498            "batch_attn_out_buf",
8499            "batch_attn_proj_buf",
8500            "batch_gate_up_buf",
8501            "batch_ffn_hidden_buf",
8502            "batch_ffn_out_buf",
8503            "batch_tokens_buf",
8504            "batch_positions_buf",
8505        ] {
8506            assert!(
8507                model_rs.contains(&format!("{buf}: Buffer")),
8508                "MetalModel should have {buf} field"
8509            );
8510        }
8511    }
8512
8513    #[test]
8514    fn generated_model_has_forward_prefill_batch() {
8515        let config = minimal_config();
8516        let model_rs = generate_model_rs(&config).unwrap();
8517
8518        assert!(
8519            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8520            "MetalModel should have forward_prefill_batch method"
8521        );
8522
8523        // forward_prefill should delegate to forward_prefill_batch
8524        assert!(
8525            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8526            "forward_prefill should delegate to forward_prefill_batch"
8527        );
8528    }
8529
8530    #[test]
8531    fn generated_model_has_max_batch_size_constant() {
8532        let config = minimal_config();
8533        let model_rs = generate_model_rs(&config).unwrap();
8534
8535        assert!(
8536            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8537            "model.rs should define MAX_BATCH_SIZE constant"
8538        );
8539    }
8540
8541    #[test]
8542    fn forward_prefill_batch_uses_batch_dispatch() {
8543        let config = minimal_config();
8544        let model_rs = generate_model_rs(&config).unwrap();
8545
8546        let batch_start = model_rs
8547            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8548            .unwrap();
8549        let batch_body = &model_rs[batch_start..];
8550        let batch_end = batch_body
8551            .find("\n    pub fn reset")
8552            .unwrap_or(batch_body.len());
8553        let batch_code = &batch_body[..batch_end];
8554
8555        // Should use batched dispatch methods
8556        assert!(
8557            batch_code.contains("dispatch_rms_norm_batch"),
8558            "forward_prefill_batch should use dispatch_rms_norm_batch"
8559        );
8560        assert!(
8561            batch_code.contains("dispatch_copy_embedding_batch"),
8562            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8563        );
8564        assert!(
8565            batch_code.contains("dispatch_silu_mul_fused_batch"),
8566            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8567        );
8568        // Should use batched causal attention dispatch
8569        assert!(
8570            batch_code.contains("dispatch_attention_batch"),
8571            "forward_prefill_batch should use dispatch_attention_batch"
8572        );
8573        // Should use fused KV cache copy (both K and V in one dispatch)
8574        assert!(
8575            batch_code.contains("dispatch_copy_kv_both_batch"),
8576            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8577        );
8578        // Should use fused RoPE Q+K dispatch
8579        assert!(
8580            batch_code.contains("dispatch_rope_qk_batch"),
8581            "forward_prefill_batch should use dispatch_rope_qk_batch"
8582        );
8583    }
8584
8585    #[test]
8586    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8587        let config = minimal_q8_config();
8588        let model_rs = generate_model_rs(&config).unwrap();
8589
8590        let batch_start = model_rs
8591            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8592            .unwrap();
8593        let batch_body = &model_rs[batch_start..];
8594        let batch_end = batch_body
8595            .find("\n    pub fn reset")
8596            .unwrap_or(batch_body.len());
8597        let batch_code = &batch_body[..batch_end];
8598
8599        assert!(
8600            batch_code.contains("dispatch_matmul_q8_batch"),
8601            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8602        );
8603    }
8604
8605    #[test]
8606    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8607        let config = minimal_q4_config();
8608        let model_rs = generate_model_rs(&config).unwrap();
8609
8610        let batch_start = model_rs
8611            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8612            .unwrap();
8613        let batch_body = &model_rs[batch_start..];
8614        let batch_end = batch_body
8615            .find("\n    pub fn reset")
8616            .unwrap_or(batch_body.len());
8617        let batch_code = &batch_body[..batch_end];
8618
8619        assert!(
8620            batch_code.contains("dispatch_matmul_q4_batch"),
8621            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8622        );
8623    }
8624
8625    #[test]
8626    fn generated_main_rs_uses_batched_prefill() {
8627        let config = minimal_config();
8628        let main_rs = generate_main_rs("test-model", &config).unwrap();
8629
8630        assert!(
8631            main_rs.contains("forward_prefill_batch"),
8632            "main.rs should use forward_prefill_batch for prompt tokens"
8633        );
8634    }
8635
8636    #[test]
8637    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8638        let config = minimal_config();
8639        let model_rs = generate_model_rs(&config).unwrap();
8640
8641        let batch_start = model_rs
8642            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8643            .unwrap();
8644        let batch_body = &model_rs[batch_start..];
8645        let batch_end = batch_body
8646            .find("\n    pub fn reset")
8647            .unwrap_or(batch_body.len());
8648        let batch_code = &batch_body[..batch_end];
8649
8650        assert!(
8651            batch_code.contains("dispatch_matmul_batch"),
8652            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8653        );
8654        // Should NOT use Q8 or Q4 batch dispatch
8655        assert!(
8656            !batch_code.contains("dispatch_matmul_q8_batch"),
8657            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8658        );
8659        assert!(
8660            !batch_code.contains("dispatch_matmul_q4_batch"),
8661            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8662        );
8663    }
8664
8665    #[test]
8666    fn forward_uses_cpu_embedding_lookup() {
8667        let config = minimal_config();
8668        let model_rs = generate_model_rs(&config).unwrap();
8669
8670        // Find just the forward() body (not forward_profile)
8671        let forward_start = model_rs
8672            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8673            .unwrap();
8674        let forward_body = &model_rs[forward_start..];
8675        let forward_end = forward_body
8676            .find("\n    pub fn forward_profile")
8677            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8678            .unwrap_or(forward_body.len());
8679        let forward_code = &forward_body[..forward_end];
8680
8681        // forward() should use CPU memcpy for embedding lookup (unified memory)
8682        assert!(
8683            forward_code.contains("embed_buf.contents()"),
8684            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8685        );
8686        assert!(
8687            forward_code.contains("copy_nonoverlapping"),
8688            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8689        );
8690        // forward() should NOT use GPU dispatch for embedding
8691        assert!(
8692            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8693            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8694        );
8695    }
8696
8697    #[test]
8698    fn forward_profile_method_exists() {
8699        let config = minimal_config();
8700        let model_rs = generate_model_rs(&config).unwrap();
8701
8702        assert!(
8703            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8704            "MetalModel should have forward_profile() method"
8705        );
8706        // Profile method should print timing information
8707        assert!(
8708            model_rs.contains("[profile]"),
8709            "forward_profile() should print timing with [profile] prefix"
8710        );
8711        assert!(
8712            model_rs.contains("d_embed"),
8713            "forward_profile() should measure embedding time"
8714        );
8715        assert!(
8716            model_rs.contains("d_layers"),
8717            "forward_profile() should measure layer time"
8718        );
8719        assert!(
8720            model_rs.contains("d_logits"),
8721            "forward_profile() should measure logits time"
8722        );
8723    }
8724
8725    #[test]
8726    fn generated_cli_has_profile_flag() {
8727        let config = minimal_config();
8728        let main_rs = generate_main_rs("test-model", &config).unwrap();
8729
8730        assert!(
8731            main_rs.contains("--profile"),
8732            "CLI should support --profile flag"
8733        );
8734        assert!(
8735            main_rs.contains("forward_profile"),
8736            "CLI should call forward_profile when --profile is set"
8737        );
8738    }
8739
8740    #[test]
8741    fn generated_cli_has_thermal_yield() {
8742        let config = minimal_config();
8743        let main_rs = generate_main_rs("test-model", &config).unwrap();
8744
8745        assert!(
8746            main_rs.contains("yield_now()"),
8747            "CLI generation loop should include thread::yield_now() for thermal management"
8748        );
8749    }
8750
8751    // ── Real-world validation tests ──────────────────────────────────────
8752
#[test]
fn generated_forward_handles_single_token_prompt() {
    // With a single token (the first prompt token), forward() should work
    // at pos=0 where seq_len=1. The attention kernel must handle the case
    // where there is only one KV entry (no prefill context).
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // The forward function should accept any u32 token_id (no minimum pos guard)
    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("forward() must exist");

    // Inspect at most the first 400 characters starting at the signature.
    // Counting characters (instead of slicing at a fixed byte offset) keeps
    // the slice inside the string and on a character boundary, so a short or
    // non-ASCII generated file produces an assertion failure rather than a
    // slicing panic.
    let forward_tail = &model_rs[forward_start..];
    let window_end = forward_tail
        .char_indices()
        .nth(400)
        .map_or(forward_tail.len(), |(offset, _)| offset);
    let forward_body = &forward_tail[..window_end];

    // Should NOT require pos > 0 or seq_len > 1
    assert!(
        !forward_body.contains("assert!(self.pos > 0"),
        "forward() must accept pos=0 (first token with no prefill)"
    );

    // The attention kernel should handle seq_len=1 via the pos field
    assert!(
        model_rs.contains("self.pos"),
        "forward() should use self.pos to track sequence position"
    );
}
8779
#[test]
fn generated_reset_clears_kv_cache_position() {
    // After reset(), the model should be in a clean state. The pos field
    // must be 0 so new generation starts from scratch.
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let reset_start = model_rs
        .find("pub fn reset(&mut self)")
        .expect("reset() must exist");

    // Look at no more than the first 200 characters of reset(). The window
    // is measured in characters and clamped to the end of the source, so it
    // cannot run past the string or split a multi-byte character (either of
    // which would panic before the assertions below get to report anything).
    let reset_tail = &model_rs[reset_start..];
    let window_end = reset_tail
        .char_indices()
        .nth(200)
        .map_or(reset_tail.len(), |(offset, _)| offset);
    let reset_body = &reset_tail[..window_end];

    // Reset must zero the position counter
    assert!(
        reset_body.contains("self.pos = 0"),
        "reset() must set self.pos = 0"
    );

    // Verify reset clears prev_cmd (double-buffering state)
    assert!(
        reset_body.contains("self.prev_cmd = None"),
        "reset() should clear prev_cmd for clean command buffer state"
    );
}
8804
#[test]
fn generated_serve_handles_empty_messages_gracefully() {
    // The serve endpoint should not crash when receiving an empty messages array.
    // The format_chat_messages function should handle this gracefully.
    let config = minimal_config();
    let main_rs = generate_main_rs("test-model", &config).unwrap();

    // The format_chat_messages function should exist and handle empty input
    let format_fn_start = main_rs
        .find("fn format_chat_messages")
        .expect("format_chat_messages must exist");

    // Take up to 500 characters of the function. Measuring the window in
    // characters keeps the end of the slice on a character boundary; a raw
    // byte offset could land inside a multi-byte character and panic.
    let format_fn_tail = &main_rs[format_fn_start..];
    let window_end = format_fn_tail
        .char_indices()
        .nth(500)
        .map_or(format_fn_tail.len(), |(offset, _)| offset);
    let format_fn_body = &format_fn_tail[..window_end];

    // It should iterate over messages (an empty slice produces an empty loop)
    assert!(
        format_fn_body.contains("for msg in messages"),
        "format_chat_messages should iterate over the messages slice"
    );
    // It should always append the assistant prompt suffix
    assert!(
        format_fn_body.contains("<|im_start|>assistant"),
        "format_chat_messages should always append assistant prompt header"
    );

    // The serve function should call model.reset() before each request.
    // Everything from `fn serve(` to the end of the file is searched.
    let serve_fn_start = main_rs
        .find("fn serve(")
        .expect("serve function must exist");
    let serve_fn_body = &main_rs[serve_fn_start..];
    assert!(
        serve_fn_body.contains("model.reset()"),
        "serve function should reset model between requests"
    );
}
8840
#[test]
fn generated_model_forward_increments_pos() {
    // Every forward() call has to advance self.pos so that the following
    // token gets the right RoPE position and KV cache offset.
    let model_rs = generate_model_rs(&minimal_config()).unwrap();

    let signature = "pub fn forward(&mut self, token_id: u32) -> Vec<f32>";
    let from_forward = &model_rs[model_rs.find(signature).unwrap()..];

    // forward() is treated as ending at the first of these markers that can
    // be found, tried in this order; if none is present, the rest of the
    // file is searched.
    let following_items = [
        "\n    pub fn forward_profile",
        "\n    pub fn forward_prefill",
        "\n    fn dispatch_",
    ];
    let cutoff = following_items
        .iter()
        .find_map(|marker| from_forward.find(*marker))
        .unwrap_or(from_forward.len());
    let forward_code = &from_forward[..cutoff];

    // Accept the increment with or without a space after the operator.
    let accepted_increments = ["self.pos += 1", "self.pos +=1"];
    assert!(
        accepted_increments
            .iter()
            .any(|stmt| forward_code.contains(*stmt)),
        "forward() must increment self.pos after processing a token"
    );
}
8864}