Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a valid Cargo package / binary name.
///
/// ASCII letters are lowercased. Every character outside `[a-z0-9]` becomes
/// `-`, and leading/trailing dashes are trimmed. Two results would otherwise
/// be rejected by Cargo:
/// - an empty name (e.g. `"___"`) falls back to `"model"`;
/// - a name starting with a digit (e.g. a model called `135M`) gets a
///   `model-` prefix.
///
/// Restricting to ASCII matters because `char::is_alphanumeric` also accepts
/// characters such as `²` that Cargo does not allow in package names.
fn sanitize_name(name: &str) -> String {
    let mapped: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_lowercase()
            } else {
                '-'
            }
        })
        .collect();
    let trimmed = mapped.trim_matches('-');

    match trimmed.chars().next() {
        None => "model".to_string(),
        Some(first) if first.is_ascii_digit() => format!("model-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    // The attention-scores shared array must be at least effective_seq_len
122    // elements; anything smaller silently overflows for long prompts.  For
123    // small-context models (135M at 2K), a smaller array saves threadgroup
124    // memory and improves occupancy — so we size it precisely.
125    let attn_scores_size = config.max_seq_len.min(4096);
126    r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154    device const float* matrix [[buffer(0)]],
155    device const float* vector [[buffer(1)]],
156    device float* output       [[buffer(2)]],
157    constant uint& rows        [[buffer(3)]],
158    constant uint& cols        [[buffer(4)]],
159    uint tgid [[threadgroup_position_in_grid]],
160    uint tid [[thread_index_in_threadgroup]],
161    uint simd_lane [[thread_index_in_simdgroup]],
162    uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164    // Cooperatively load vector into threadgroup shared memory
165    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166    for (uint i = tid; i < cols; i += 256) {
167        vec_tile[i] = vector[i];
168    }
169    threadgroup_barrier(mem_flags::mem_threadgroup);
170
171    // Each simdgroup handles 8 consecutive rows
172    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173    if (row_base >= rows) return;
174
175    uint base0 = row_base * cols;
176    uint base1 = (row_base + 1) * cols;
177    uint base2 = (row_base + 2) * cols;
178    uint base3 = (row_base + 3) * cols;
179    uint base4 = (row_base + 4) * cols;
180    uint base5 = (row_base + 5) * cols;
181    uint base6 = (row_base + 6) * cols;
182    uint base7 = (row_base + 7) * cols;
183
184    // float4 vectorized accumulation across 8 rows
185    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
186    float4 sum4_0 = float4(0.0f);
187    float4 sum4_1 = float4(0.0f);
188    float4 sum4_2 = float4(0.0f);
189    float4 sum4_3 = float4(0.0f);
190    float4 sum4_4 = float4(0.0f);
191    float4 sum4_5 = float4(0.0f);
192    float4 sum4_6 = float4(0.0f);
193    float4 sum4_7 = float4(0.0f);
194
195    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205    }
206
207    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216    // Handle remaining elements (cols not divisible by 128)
217    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218        float vv = vec_tile[j];
219        sum0 += matrix[base0 + j] * vv;
220        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227    }
228
229    // Simdgroup hardware warp-level reduction
230    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235    // Only first lane writes the results
236    if (simd_lane == 0) {
237        if (row_base     < rows) output[row_base]     = sum0;
238        if (row_base + 1 < rows) output[row_base + 1] = sum1;
239        if (row_base + 2 < rows) output[row_base + 2] = sum2;
240        if (row_base + 3 < rows) output[row_base + 3] = sum3;
241        if (row_base + 4 < rows) output[row_base + 4] = sum4;
242        if (row_base + 5 < rows) output[row_base + 5] = sum5;
243        if (row_base + 6 < rows) output[row_base + 6] = sum6;
244        if (row_base + 7 < rows) output[row_base + 7] = sum7;
245    }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253    device const float* input   [[buffer(0)]],
254    device const float* weight  [[buffer(1)]],
255    device float* output        [[buffer(2)]],
256    constant uint& n            [[buffer(3)]],
257    constant float& eps         [[buffer(4)]],
258    uint tid [[thread_index_in_threadgroup]])
259{
260    // Each thread accumulates partial sum-of-squares
261    float sum_sq = 0.0f;
262    for (uint i = tid; i < n; i += 256) {
263        float v = input[i];
264        sum_sq += v * v;
265    }
266
267    // Simdgroup-level reduction (hardware warp sum)
268    sum_sq = simd_sum(sum_sq);
269
270    // Cross-simdgroup reduction via shared memory
271    threadgroup float shared[8];
272    uint simd_id = tid / 32;
273    uint simd_lane = tid % 32;
274    if (simd_lane == 0) {
275        shared[simd_id] = sum_sq;
276    }
277    threadgroup_barrier(mem_flags::mem_threadgroup);
278
279    // First thread computes final inverse RMS
280    if (tid == 0) {
281        float total = 0.0f;
282        for (uint i = 0; i < 8; i++) {
283            total += shared[i];
284        }
285        shared[0] = fast::rsqrt(total / float(n) + eps);
286    }
287    threadgroup_barrier(mem_flags::mem_threadgroup);
288
289    float inv_rms = shared[0];
290
291    // Normalize
292    for (uint i = tid; i < n; i += 256) {
293        output[i] = input[i] * inv_rms * weight[i];
294    }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301    device float* data        [[buffer(0)]],
302    constant uint& num_heads  [[buffer(1)]],
303    constant uint& head_dim   [[buffer(2)]],
304    constant uint& pos        [[buffer(3)]],
305    constant float& theta     [[buffer(4)]],
306    uint id [[thread_position_in_grid]])
307{
308    uint half_dim = head_dim / 2;
309    uint total_pairs = num_heads * half_dim;
310    if (id >= total_pairs) return;
311
312    uint h = id / half_dim;
313    uint i = id % half_dim;
314    uint off = h * head_dim;
315
316    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317    float angle = float(pos) * freq;
318    float c = cos(angle);
319    float s = sin(angle);
320
321    float x0 = data[off + 2 * i];
322    float x1 = data[off + 2 * i + 1];
323    data[off + 2 * i]     = x0 * c - x1 * s;
324    data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331    device float* data       [[buffer(0)]],
332    constant uint& n         [[buffer(1)]],
333    uint tid [[thread_index_in_threadgroup]],
334    uint tg_size [[threads_per_threadgroup]])
335{
336    threadgroup float shared_val[256];
337
338    // Pass 1: find max
339    float local_max = -INFINITY;
340    for (uint i = tid; i < n; i += tg_size) {
341        local_max = max(local_max, data[i]);
342    }
343    shared_val[tid] = local_max;
344    threadgroup_barrier(mem_flags::mem_threadgroup);
345
346    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347        if (tid < stride) {
348            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349        }
350        threadgroup_barrier(mem_flags::mem_threadgroup);
351    }
352    float max_val = shared_val[0];
353    threadgroup_barrier(mem_flags::mem_threadgroup);
354
355    // Pass 2: exp and sum
356    float local_sum = 0.0f;
357    for (uint i = tid; i < n; i += tg_size) {
358        float e = fast::exp(data[i] - max_val);
359        data[i] = e;
360        local_sum += e;
361    }
362    shared_val[tid] = local_sum;
363    threadgroup_barrier(mem_flags::mem_threadgroup);
364
365    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366        if (tid < stride) {
367            shared_val[tid] += shared_val[tid + stride];
368        }
369        threadgroup_barrier(mem_flags::mem_threadgroup);
370    }
371    float inv_sum = 1.0f / shared_val[0];
372    threadgroup_barrier(mem_flags::mem_threadgroup);
373
374    // Pass 3: normalize
375    for (uint i = tid; i < n; i += tg_size) {
376        data[i] *= inv_sum;
377    }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384    device const float* gate [[buffer(0)]],
385    device const float* up   [[buffer(1)]],
386    device float* output     [[buffer(2)]],
387    constant uint& n         [[buffer(3)]],
388    uint id [[thread_position_in_grid]])
389{
390    if (id >= n) return;
391    float g = gate[id];
392    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397//   gate = gate_up[0..n], up = gate_up[n..2*n]
398//   output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400    device const float* gate_up [[buffer(0)]],
401    device float* output        [[buffer(1)]],
402    constant uint& n            [[buffer(2)]],
403    uint id [[thread_position_in_grid]])
404{
405    if (id >= n) return;
406    float g = gate_up[id];
407    float u = gate_up[n + id];
408    output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414    device const float* a  [[buffer(0)]],
415    device const float* b  [[buffer(1)]],
416    device float* output   [[buffer(2)]],
417    constant uint& n       [[buffer(3)]],
418    uint id [[thread_position_in_grid]])
419{
420    if (id >= n) return;
421    output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428    device const float* src [[buffer(0)]],
429    device float* dst       [[buffer(1)]],
430    constant uint& count    [[buffer(2)]],
431    uint id [[thread_position_in_grid]])
432{
433    if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440    device const float* src     [[buffer(0)]],
441    device float* dst           [[buffer(1)]],
442    constant uint& src_offset   [[buffer(2)]],  // in floats
443    constant uint& count        [[buffer(3)]],
444    uint id [[thread_position_in_grid]])
445{
446    if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
450// Copy f32 elements from src into a half-typed dst, converting to half on
451// write.  Used by the single-token decode path to append a new K/V vector
452// to the f16 KV cache.  Byte offsets into src/dst are supplied via the
453// Metal buffer binding offsets — no in-kernel offsets needed.
454kernel void copy_f32_to_f16_offset(
455    device const float* src     [[buffer(0)]],
456    device half* dst            [[buffer(1)]],
457    constant uint& count        [[buffer(2)]],
458    uint id [[thread_position_in_grid]])
459{
460    if (id < count) dst[id] = half(src[id]);
461}
462
463// ── add_inplace ─────────────────────────────────────────────────────────
464// In-place residual connection: a[i] += b[i]
465// Avoids a separate blit copy for residual add, reducing encoder overhead.
466kernel void add_inplace(
467    device float* a        [[buffer(0)]],
468    device const float* b  [[buffer(1)]],
469    constant uint& n       [[buffer(2)]],
470    uint id [[thread_position_in_grid]])
471{
472    if (id >= n) return;
473    a[id] += b[id];
474}
475
476// ── matmul_vec_q8 ─────────────────────────────────────────────────────
477// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
478// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
479// Operates directly on quantized weights to halve memory bandwidth vs f32,
480// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
481//
482// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
483// because int8->float conversion doubles register demand.  Fully unrolled
484// inner loop with float4 vector loads from shared memory eliminates loop
485// overhead and enables better instruction scheduling.
486// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
487constant constexpr uint Q8_ROWS_PER_SG = 4;
488constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
489
490// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
491// doubles ALU work, so the same register budget applies).
492constant constexpr uint Q4_ROWS_PER_SG = 4;
493constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
494
495kernel void matmul_vec_q8(
496    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
497    device const float* vector   [[buffer(1)]],  // f32 input
498    device float* output         [[buffer(2)]],
499    constant uint& rows          [[buffer(3)]],
500    constant uint& cols          [[buffer(4)]],  // number of elements per row
501    uint tgid [[threadgroup_position_in_grid]],
502    uint tid [[thread_index_in_threadgroup]],
503    uint simd_lane [[thread_index_in_simdgroup]],
504    uint simd_id [[simdgroup_index_in_threadgroup]])
505{
506    // Load vector into shared memory
507    threadgroup float vec_tile[VEC_TILE_SIZE];
508    for (uint i = tid; i < cols; i += 256) {
509        vec_tile[i] = vector[i];
510    }
511    threadgroup_barrier(mem_flags::mem_threadgroup);
512
513    // Each simdgroup handles 4 consecutive rows (lower register pressure)
514    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
515    if (row_base >= rows) return;
516
517    // Q8_0: each block is 34 bytes for 32 elements
518    uint blocks_per_row = cols / 32;
519    uint row_bytes = blocks_per_row * 34;
520
521    // Pointers to each row's Q8_0 data
522    device const uchar* r0 = matrix + row_base * row_bytes;
523    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
524    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
525    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
526
527    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
528
529    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
530        uint bb = blk * 34;
531        uint vb = blk * 32;
532
533        // Prefetch all 4 scales
534        float sc0 = float(*(device const half*)(r0 + bb));
535        float sc1 = float(*(device const half*)(r1 + bb));
536        float sc2 = float(*(device const half*)(r2 + bb));
537        float sc3 = float(*(device const half*)(r3 + bb));
538
539        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
540        // Q8_0 block layout where the int8 data starts at offset +2 from a
541        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
542        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
543        // reduction in memory transactions. Metal's char16/packed_char16 are
544        // reserved types and packed_*int4 require >=4-byte alignment which
545        // this layout does not provide, so packed_short4 is the widest valid
546        // vectorized load.
547        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
548        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
549        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
550        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
551
552        // Load all 8 float4 vector values for this 32-element block from shared memory
553        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
554        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
555        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
556        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
557        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
558        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
559        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
560        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
561
562        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
563        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
564        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
565            short4 _s = short4(SHORT4); \
566            char2 _a = as_type<char2>(_s.x); \
567            char2 _b = as_type<char2>(_s.y); \
568            char2 _c = as_type<char2>(_s.z); \
569            char2 _d = as_type<char2>(_s.w); \
570            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
571            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
572        }
573
574        float4 f0, f1;
575        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
576
577        // Row 0: 4 short4 loads cover 32 int8 weights
578        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
579        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
580        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
581        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
582
583        // Row 1
584        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
585        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
586        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
587        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
588
589        // Row 2
590        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
591        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
592        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
593        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
594
595        // Row 3
596        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
597        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
598        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
599        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
600
601        #undef Q8_UNPACK8
602
603        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
604    }
605
606    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
607    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
608
609    if (simd_lane == 0) {
610        if (row_base     < rows) output[row_base]     = sum0;
611        if (row_base + 1 < rows) output[row_base + 1] = sum1;
612        if (row_base + 2 < rows) output[row_base + 2] = sum2;
613        if (row_base + 3 < rows) output[row_base + 3] = sum3;
614    }
615}
616
617// ── matmul_vec_q4 ─────────────────────────────────────────────────────
618// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
619// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
620// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
621// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
622//
623// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
624// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
625kernel void matmul_vec_q4(
626    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
627    device const float* vector   [[buffer(1)]],  // f32 input
628    device float* output         [[buffer(2)]],
629    constant uint& rows          [[buffer(3)]],
630    constant uint& cols          [[buffer(4)]],  // number of elements per row
631    uint tgid [[threadgroup_position_in_grid]],
632    uint tid [[thread_index_in_threadgroup]],
633    uint simd_lane [[thread_index_in_simdgroup]],
634    uint simd_id [[simdgroup_index_in_threadgroup]])
635{
636    // Load vector into shared memory
637    threadgroup float vec_tile[VEC_TILE_SIZE];
638    for (uint i = tid; i < cols; i += 256) {
639        vec_tile[i] = vector[i];
640    }
641    threadgroup_barrier(mem_flags::mem_threadgroup);
642
643    // Each simdgroup handles 4 consecutive rows
644    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
645    if (row_base >= rows) return;
646
647    // Q4_0: each block is 18 bytes for 32 elements
648    uint blocks_per_row = cols / 32;
649    uint row_bytes = blocks_per_row * 18;
650
651    // Pointers to each row's Q4_0 data
652    device const uchar* r0 = matrix + row_base * row_bytes;
653    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
654    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
655    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
656
657    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
658
659    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
660        uint bb = blk * 18;
661        uint vb = blk * 32;
662
663        // Prefetch all 4 scales
664        float sc0 = float(*(device const half*)(r0 + bb));
665        float sc1 = float(*(device const half*)(r1 + bb));
666        float sc2 = float(*(device const half*)(r2 + bb));
667        float sc3 = float(*(device const half*)(r3 + bb));
668
669        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
670        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
671        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
672        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
673        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
674
675        // Load 8 float4 vector values for 32 elements from shared memory
676        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
677        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
678        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
679        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
680        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
681        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
682        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
683        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
684        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
685
686        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
687        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
688        float bd0=0, bd1=0, bd2=0, bd3=0;
689        uchar4 b;
690
691        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
692        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
693                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
694                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
695                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
696        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
697                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
698                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
699                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
700        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
701                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
702                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
703                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
704        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
705                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
706                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
707                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
708
709        // Row 1
710        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
711                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
712                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
713                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
714        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
715                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
716                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
717                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
718        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
719                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
720                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
721                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
722        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
723                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
724                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
725                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
726
727        // Row 2
728        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
729                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
730                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
731                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
732        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
733                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
734                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
735                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
736        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
737                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
738                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
739                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
740        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
741                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
742                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
743                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
744
745        // Row 3
746        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
747                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
748                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
749                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
750        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
751                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
752                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
753                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
754        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
755                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
756                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
757                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
758        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
759                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
760                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
761                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
762
763        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
764    }
765
766    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
767    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
768
769    if (simd_lane == 0) {
770        if (row_base     < rows) output[row_base]     = sum0;
771        if (row_base + 1 < rows) output[row_base + 1] = sum1;
772        if (row_base + 2 < rows) output[row_base + 2] = sum2;
773        if (row_base + 3 < rows) output[row_base + 3] = sum3;
774    }
775}
776
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then a weighted sum of V.
// One threadgroup per query head. The kernel must be dispatched with exactly
// 256 threads per threadgroup (8 simdgroups of 32): the position/softmax
// strides and the 8-entry reduction arrays below are hard-wired to that size.
//
// Grouped-query attention: query head h reads KV head
// h / (num_heads / num_kv_heads), so num_heads is assumed to be a multiple
// of num_kv_heads.
//
// Buffers:
//   q:       [num_heads * head_dim]                    current query (float)
//   k_cache: [max_seq_len * num_kv_heads * head_dim]   (half)
//   v_cache: [max_seq_len * num_kv_heads * head_dim]   (half)
//   output:  [num_heads * head_dim]
//
// Precondition: seq_len <= ATTN_SCORES_SIZE. Scores are held in a fixed-size
// threadgroup array with no bounds check here, so the host must cap the
// context length (MAX_SEQ_LEN) to that size; a longer seq_len would write
// past the end of threadgroup memory.
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const half*  k_cache  [[buffer(1)]],
    device const half*  v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;   // uniform across the threadgroup
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;
    uint kv_stride = num_kv_heads * head_dim;  // elements per cached position
    uint kv_off = kv_head * head_dim;          // this KV head's slice of a position
    // 1/sqrt(head_dim) is loop-invariant: compute it once, not per position.
    float scale = fast::rsqrt(float(head_dim));

    // Step 1: scores[s] = (q . k_s) * scale. Each simdgroup takes every 8th
    // position; its 32 lanes split head_dim and combine with simd_sum.
    threadgroup float scores[ATTN_SCORES_SIZE];

    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * kv_stride + kv_off;
        float qk = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            qk += q[q_off + d] * k_cache[k_off + d];
        }
        qk = simd_sum(qk);
        if (simd_lane == 0) {
            scores[s] = qk * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: numerically stable softmax over scores[0 .. seq_len).
    // 2a: maximum, reduced per simdgroup and then across the 8 simdgroups.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate in place and reduce the sum the same way.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    // 2c: normalize.
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: output = scores . V. Each thread owns a subset of head_dim
    // dimensions; positions are consumed 4 at a time for ILP, then a
    // scalar tail handles seq_len % 4.
    uint seq_len4 = seq_len & ~3u;  // largest multiple of 4 <= seq_len
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_off + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * v_cache[s * kv_stride + v_base]
                 + sc1 * v_cache[(s+1) * kv_stride + v_base]
                 + sc2 * v_cache[(s+2) * kv_stride + v_base]
                 + sc3 * v_cache[(s+3) * kv_stride + v_base];
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * kv_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
893
894// ── Batched prefill kernels ────────────────────────────────────────────
895// These kernels process M input vectors against the same weight matrix
896// in a single dispatch, converting mat-vec into mat-mat for better GPU
897// utilization during prompt prefill.
898
899// ── rms_norm_batch ─────────────────────────────────────────────────────
900// RMS normalization for a batch of vectors.
901// Each threadgroup handles one vector: input[token * n .. (token+1) * n].
902// Grid: M threadgroups (one per token).
903kernel void rms_norm_batch(
904    device const float* input   [[buffer(0)]],  // [M, n]
905    device const float* weight  [[buffer(1)]],  // [n]
906    device float* output        [[buffer(2)]],  // [M, n]
907    constant uint& n            [[buffer(3)]],
908    constant float& eps         [[buffer(4)]],
909    constant uint& num_tokens   [[buffer(5)]],
910    uint tgid [[threadgroup_position_in_grid]],
911    uint tid [[thread_index_in_threadgroup]])
912{
913    if (tgid >= num_tokens) return;
914
915    uint base = tgid * n;
916
917    float sum_sq = 0.0f;
918    for (uint i = tid; i < n; i += 256) {
919        float v = input[base + i];
920        sum_sq += v * v;
921    }
922
923    sum_sq = simd_sum(sum_sq);
924
925    threadgroup float shared[8];
926    uint simd_id = tid / 32;
927    uint simd_lane = tid % 32;
928    if (simd_lane == 0) {
929        shared[simd_id] = sum_sq;
930    }
931    threadgroup_barrier(mem_flags::mem_threadgroup);
932
933    if (tid == 0) {
934        float total = 0.0f;
935        for (uint i = 0; i < 8; i++) {
936            total += shared[i];
937        }
938        shared[0] = fast::rsqrt(total / float(n) + eps);
939    }
940    threadgroup_barrier(mem_flags::mem_threadgroup);
941
942    float inv_rms = shared[0];
943
944    for (uint i = tid; i < n; i += 256) {
945        output[base + i] = input[base + i] * inv_rms * weight[i];
946    }
947}
948
// ── rope_batch ─────────────────────────────────────────────────────────
// Batched rotary position embedding, applied in place.
// data is laid out [M, num_heads * head_dim]; positions[t] is token t's
// position. One thread per (token, head, pair): pair p of a head is the
// adjacent element pair (2p, 2p + 1), rotated by the angle
// positions[t] / theta^(2p / head_dim).
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    uint pairs_per_head  = head_dim / 2;
    uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread index into (token, head, pair).
    uint token         = id / pairs_per_token;
    uint pair_in_token = id % pairs_per_token;
    uint head          = pair_in_token / pairs_per_head;
    uint pair          = pair_in_token % pairs_per_head;

    device float* elems = data + token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    float angle = float(positions[token]) * inv_freq;
    float cos_a = cos(angle);
    float sin_a = sin(angle);

    float even = elems[0];
    float odd  = elems[1];
    elems[0] = even * cos_a - odd * sin_a;
    elems[1] = even * sin_a + odd * cos_a;
}
983
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Batched SiLU-gated multiply. Each token row of gate_up is [gate | up],
// n elements each ([M, 2*n] overall), and
//   output[t, i] = silu(gate_up[t, i]) * gate_up[t, n + i]
// where silu(g) = g / (1 + exp(-g)). One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    uint token = id / n;
    uint col   = id % n;
    device const float* row = gate_up + token * 2 * n;

    float gate  = row[col];
    float up    = row[n + col];
    float gated = gate / (1.0f + fast::exp(-gate));
    output[token * n + col] = gated * up;
}
1003
// ── add_inplace_batch ──────────────────────────────────────────────────
// In-place residual add over a flattened [M * n] batch: a[i] = a[i] + b[i].
// One thread per element; threads at or beyond `total` do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1015
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gathers M embedding rows into a contiguous [M, dim] batch buffer:
//   output[t, d] = embed[tokens[t], d]
// One thread per output element. Token IDs are not range-checked here;
// they are assumed to be < vocab_size.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    uint slot     = id / dim;          // which token in the batch
    uint col      = id - slot * dim;   // element within the embedding row
    uint token_id = tokens[slot];
    output[id] = embed[token_id * dim + col];
}
1033
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched f32 matrix-vector multiply: outputs[t, :] = matrix * inputs[t, :]
// for each of the M tokens, sharing one [rows, cols] weight matrix.
// Grid: ceil(rows/ROWS_PER_TG) * M threadgroups, token-major: threadgroup
// tgid serves token tgid / ceil(rows/ROWS_PER_TG) and that token's
// (tgid % ceil(rows/ROWS_PER_TG))-th row group.
// Each simdgroup computes 8 consecutive rows starting at
// simd_id * ROWS_PER_SIMDGROUP (presumably ROWS_PER_SIMDGROUP == 8 — confirm);
// lanes stride across columns and each row is reduced with simd_sum.
// Not checked here: cols <= VEC_TILE_SIZE, 256 threads per threadgroup, and
// cols a multiple of 4 so the float4 weight loads are 16-byte aligned.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / groups_per_token;
    uint group = tgid % groups_per_token;
    if (token >= num_tokens) return;   // uniform across the threadgroup

    // Stage this token's input vector in threadgroup memory; every simdgroup
    // reads the whole vector.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* x = inputs + token * cols;
    for (uint c = tid; c < cols; c += 256) {
        vec_tile[c] = x[c];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row0 = group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row0 >= rows) return;

    // Element offsets of the 8 rows owned by this simdgroup. Rows 1..7 may lie
    // past the matrix; they are only dereferenced when their has* flag is set.
    uint off0 = row0 * cols;
    uint off1 = off0 + cols;
    uint off2 = off1 + cols;
    uint off3 = off2 + cols;
    uint off4 = off3 + cols;
    uint off5 = off4 + cols;
    uint off6 = off5 + cols;
    uint off7 = off6 + cols;

    bool has1 = row0 + 1 < rows;
    bool has2 = row0 + 2 < rows;
    bool has3 = row0 + 3 < rows;
    bool has4 = row0 + 4 < rows;
    bool has5 = row0 + 5 < rows;
    bool has6 = row0 + 6 < rows;
    bool has7 = row0 + 7 < rows;

    // Vector main loop: each lane consumes 4 columns per step, so the whole
    // simdgroup covers 128 columns per iteration.
    uint vec_end = cols & ~127u;
    float4 acc0 = float4(0.0f), acc1 = float4(0.0f), acc2 = float4(0.0f), acc3 = float4(0.0f);
    float4 acc4 = float4(0.0f), acc5 = float4(0.0f), acc6 = float4(0.0f), acc7 = float4(0.0f);

    for (uint c = simd_lane * 4; c < vec_end; c += 128) {
        float4 xv = *reinterpret_cast<threadgroup const float4*>(vec_tile + c);
        acc0 += *reinterpret_cast<device const float4*>(matrix + off0 + c) * xv;
        if (has1) acc1 += *reinterpret_cast<device const float4*>(matrix + off1 + c) * xv;
        if (has2) acc2 += *reinterpret_cast<device const float4*>(matrix + off2 + c) * xv;
        if (has3) acc3 += *reinterpret_cast<device const float4*>(matrix + off3 + c) * xv;
        if (has4) acc4 += *reinterpret_cast<device const float4*>(matrix + off4 + c) * xv;
        if (has5) acc5 += *reinterpret_cast<device const float4*>(matrix + off5 + c) * xv;
        if (has6) acc6 += *reinterpret_cast<device const float4*>(matrix + off6 + c) * xv;
        if (has7) acc7 += *reinterpret_cast<device const float4*>(matrix + off7 + c) * xv;
    }

    // Collapse each float4 accumulator (x + y + z + w, in that order).
    float r0 = acc0.x + acc0.y + acc0.z + acc0.w;
    float r1 = acc1.x + acc1.y + acc1.z + acc1.w;
    float r2 = acc2.x + acc2.y + acc2.z + acc2.w;
    float r3 = acc3.x + acc3.y + acc3.z + acc3.w;
    float r4 = acc4.x + acc4.y + acc4.z + acc4.w;
    float r5 = acc5.x + acc5.y + acc5.z + acc5.w;
    float r6 = acc6.x + acc6.y + acc6.z + acc6.w;
    float r7 = acc7.x + acc7.y + acc7.z + acc7.w;

    // Scalar tail for the columns the 128-wide loop did not reach.
    for (uint c = vec_end + simd_lane; c < cols; c += 32) {
        float xs = vec_tile[c];
        r0 += matrix[off0 + c] * xs;
        if (has1) r1 += matrix[off1 + c] * xs;
        if (has2) r2 += matrix[off2 + c] * xs;
        if (has3) r3 += matrix[off3 + c] * xs;
        if (has4) r4 += matrix[off4 + c] * xs;
        if (has5) r5 += matrix[off5 + c] * xs;
        if (has6) r6 += matrix[off6 + c] * xs;
        if (has7) r7 += matrix[off7 + c] * xs;
    }

    r0 = simd_sum(r0); r1 = simd_sum(r1);
    r2 = simd_sum(r2); r3 = simd_sum(r3);
    r4 = simd_sum(r4); r5 = simd_sum(r5);
    r6 = simd_sum(r6); r7 = simd_sum(r7);

    // Lane 0 publishes the simdgroup's rows; row0 is known to be in range.
    if (simd_lane != 0) return;
    device float* y = outputs + token * rows;
    y[row0] = r0;
    if (has1) y[row0 + 1] = r1;
    if (has2) y[row0 + 2] = r2;
    if (has3) y[row0 + 3] = r3;
    if (has4) y[row0 + 4] = r4;
    if (has5) y[row0 + 5] = r5;
    if (has6) y[row0 + 6] = r6;
    if (has7) y[row0 + 7] = r7;
}
1135
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors:
//   outputs[t, :] = dequant(matrix) * inputs[t, :]
// Q8_0 row layout: cols / 32 blocks of 34 bytes, each a half scale followed
// by 32 signed 8-bit quants (cols is assumed to be a multiple of 32).
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups, token-major.
// Each simdgroup computes 4 consecutive rows starting at
// simd_id * Q8_ROWS_PER_SG (presumably Q8_ROWS_PER_SG == 4 — confirm);
// lanes stride over blocks and each row is reduced with simd_sum.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;   // uniform across the threadgroup

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // The block loop below reads all 4 rows unconditionally. When rows is not
    // a multiple of 4, rows row_base+1..+3 may not exist: clamp them to the
    // last valid row so every weight load stays inside the buffer. Their sums
    // are computed from a duplicate row and discarded by the guarded stores.
    uint last_row = rows - 1;   // rows >= 1 because row_base < rows
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // element offset of this block within the vector

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Splits 8 packed int8 quants (one packed_short4) into two float4s,
        // preserving byte order (low byte of each short first).
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Only rows that actually exist are written; clamped duplicates are dropped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1251
1252// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
1253// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
1254// Each threadgroup covers 32 rows and TOKENS_PER_TG consecutive tokens, so
1255// the Q8_0 weight blocks are fetched once from device memory and reused for
1256// every token in the tile (1/TOKENS_PER_TG the weight bandwidth of the
1257// per-token dispatch).
1258//
1259// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG)) threadgroups.
1260// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
1261// with simd_sum.  Token vectors are read directly from device memory inside
1262// the block loop (not cached in shared memory) so intermediate_size up to
1263// 8192 fits without spilling threadgroup memory.
1264constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1265
1266kernel void matmul_q8_gemm_batch(
1267    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1268    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1269    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1270    constant uint& num_tokens    [[buffer(3)]],  // M
1271    constant uint& rows          [[buffer(4)]],
1272    constant uint& cols          [[buffer(5)]],
1273    uint2 tgid [[threadgroup_position_in_grid]],
1274    uint tid [[thread_index_in_threadgroup]],
1275    uint simd_lane [[thread_index_in_simdgroup]],
1276    uint simd_id [[simdgroup_index_in_threadgroup]])
1277{
1278    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1279    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1280    if (row_base >= rows || tok_base >= num_tokens) return;
1281
1282    // How many tokens in this tile are valid?
1283    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1284
1285    uint blocks_per_row = cols / 32;
1286    uint row_bytes = blocks_per_row * 34;
1287
1288    device const uchar* r0 = matrix + row_base * row_bytes;
1289    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1290    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1291    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1292
1293    // Accumulators: 4 tokens × 4 rows per simdgroup.
1294    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1295    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1296    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1297    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1298
1299    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1300        uint bb = blk * 34;
1301        uint vb = blk * 32;
1302
1303        // ── Load weight data ONCE per block (reused across all tokens) ──
1304        float sc0 = float(*(device const half*)(r0 + bb));
1305        float sc1 = float(*(device const half*)(r1 + bb));
1306        float sc2 = float(*(device const half*)(r2 + bb));
1307        float sc3 = float(*(device const half*)(r3 + bb));
1308
1309        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1310        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1311        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1312        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1313
1314        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1315            short4 _s = short4(SHORT4); \
1316            char2 _a = as_type<char2>(_s.x); \
1317            char2 _b = as_type<char2>(_s.y); \
1318            char2 _c = as_type<char2>(_s.z); \
1319            char2 _d = as_type<char2>(_s.w); \
1320            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1321            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1322        }
1323
1324        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1325        // registers for the duration of the block and are dotted against
1326        // every token's vector tile.
1327        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1328        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1329        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1330        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1331
1332        Q8_UNPACK8(d0[0], w0_0, w0_1);
1333        Q8_UNPACK8(d0[1], w0_2, w0_3);
1334        Q8_UNPACK8(d0[2], w0_4, w0_5);
1335        Q8_UNPACK8(d0[3], w0_6, w0_7);
1336
1337        Q8_UNPACK8(d1[0], w1_0, w1_1);
1338        Q8_UNPACK8(d1[1], w1_2, w1_3);
1339        Q8_UNPACK8(d1[2], w1_4, w1_5);
1340        Q8_UNPACK8(d1[3], w1_6, w1_7);
1341
1342        Q8_UNPACK8(d2[0], w2_0, w2_1);
1343        Q8_UNPACK8(d2[1], w2_2, w2_3);
1344        Q8_UNPACK8(d2[2], w2_4, w2_5);
1345        Q8_UNPACK8(d2[3], w2_6, w2_7);
1346
1347        Q8_UNPACK8(d3[0], w3_0, w3_1);
1348        Q8_UNPACK8(d3[1], w3_2, w3_3);
1349        Q8_UNPACK8(d3[2], w3_4, w3_5);
1350        Q8_UNPACK8(d3[3], w3_6, w3_7);
1351
1352        #undef Q8_UNPACK8
1353
1354        // ── For each token, read vector and accumulate against shared weights ──
1355        // Token 0 (always valid: tok_count >= 1).
1356        {
1357            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1358            float4 v0 = *(device const float4*)(a0);
1359            float4 v1 = *(device const float4*)(a0 + 4);
1360            float4 v2 = *(device const float4*)(a0 + 8);
1361            float4 v3 = *(device const float4*)(a0 + 12);
1362            float4 v4 = *(device const float4*)(a0 + 16);
1363            float4 v5 = *(device const float4*)(a0 + 20);
1364            float4 v6 = *(device const float4*)(a0 + 24);
1365            float4 v7 = *(device const float4*)(a0 + 28);
1366            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1367                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1368            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1369                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1370            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1371                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1372            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1373                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1374            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1375        }
1376        // Token 1
1377        if (tok_count > 1) {
1378            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1379            float4 v0 = *(device const float4*)(a1);
1380            float4 v1 = *(device const float4*)(a1 + 4);
1381            float4 v2 = *(device const float4*)(a1 + 8);
1382            float4 v3 = *(device const float4*)(a1 + 12);
1383            float4 v4 = *(device const float4*)(a1 + 16);
1384            float4 v5 = *(device const float4*)(a1 + 20);
1385            float4 v6 = *(device const float4*)(a1 + 24);
1386            float4 v7 = *(device const float4*)(a1 + 28);
1387            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1388                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1389            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1390                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1391            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1392                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1393            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1394                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1395            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1396        }
1397        // Token 2
1398        if (tok_count > 2) {
1399            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1400            float4 v0 = *(device const float4*)(a2);
1401            float4 v1 = *(device const float4*)(a2 + 4);
1402            float4 v2 = *(device const float4*)(a2 + 8);
1403            float4 v3 = *(device const float4*)(a2 + 12);
1404            float4 v4 = *(device const float4*)(a2 + 16);
1405            float4 v5 = *(device const float4*)(a2 + 20);
1406            float4 v6 = *(device const float4*)(a2 + 24);
1407            float4 v7 = *(device const float4*)(a2 + 28);
1408            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1409                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1410            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1411                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1412            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1413                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1414            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1415                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1416            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1417        }
1418        // Token 3
1419        if (tok_count > 3) {
1420            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1421            float4 v0 = *(device const float4*)(a3);
1422            float4 v1 = *(device const float4*)(a3 + 4);
1423            float4 v2 = *(device const float4*)(a3 + 8);
1424            float4 v3 = *(device const float4*)(a3 + 12);
1425            float4 v4 = *(device const float4*)(a3 + 16);
1426            float4 v5 = *(device const float4*)(a3 + 20);
1427            float4 v6 = *(device const float4*)(a3 + 24);
1428            float4 v7 = *(device const float4*)(a3 + 28);
1429            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1430                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1431            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1432                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1433            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1434                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1435            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1436                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1437            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1438        }
1439    }
1440
1441    // simdgroup reduction
1442    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1443    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1444    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1445    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1446
1447    if (simd_lane == 0) {
1448        device float* o0 = outputs + (tok_base + 0) * rows;
1449        if (row_base     < rows) o0[row_base]     = s00;
1450        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1451        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1452        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1453
1454        if (tok_count > 1) {
1455            device float* o1 = outputs + (tok_base + 1) * rows;
1456            if (row_base     < rows) o1[row_base]     = s10;
1457            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1458            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1459            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1460        }
1461        if (tok_count > 2) {
1462            device float* o2 = outputs + (tok_base + 2) * rows;
1463            if (row_base     < rows) o2[row_base]     = s20;
1464            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1465            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1466            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1467        }
1468        if (tok_count > 3) {
1469            device float* o3 = outputs + (tok_base + 3) * rows;
1470            if (row_base     < rows) o3[row_base]     = s30;
1471            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1472            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1473            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1474        }
1475    }
1476}
1477
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
//
// Computes outputs[t * rows + r] = sum_k inputs[t * cols + k] * W[r, k],
// where W[r, k] is the dequantized Q8_0 weight.  As read below, each 34-byte
// Q8_0 block is a half scale followed by 32 int8 quants.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups (128 threads) per TG, each computing a single 8×8 output
// sub-tile via one simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes
// are cooperatively dequantized into threadgroup memory once per block and
// reused by all simdgroups in the tile.
//
// Threadgroup memory: w_tile + t_tile = 4 KB.  The partial-token store path
// reuses w_tile as its staging buffer, so it needs no extra allocation.
//
// Assumptions (verified in the dispatch helper, falls back otherwise):
//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
//   * num_tokens may be any value; a partial token tile at the boundary is
//     handled via a scratch copy path.
// The cooperative loads below also require exactly 128 threads per
// threadgroup (4 elements per thread × 128 = 512 tile elements).
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and no
    // thread is left waiting at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8;  // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // bytes per weight row (34-byte blocks)

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);            // block scale
                int ival = int(*(device const int8_t*)(rp + 2 + k));  // k-th quant
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Token rows past num_tokens are zero-filled so the MMA below always
        // reads defined data.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Tiles must be fully written before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B,
                w_tile + sg_row_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Keep the next iteration's tile writes from racing with this
        // iteration's simdgroup_load reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage the 8×8 result and scalar-copy only the
        // valid token rows.  w_tile doubles as the staging buffer: the final
        // threadgroup_barrier of the K loop guarantees every simdgroup has
        // finished reading it, and each simdgroup writes its own disjoint
        // 64-float slice (4 × 64 = 256 <= 512 floats).
        threadgroup float* scratch = w_tile;
        simdgroup_store(C, scratch + simd_id * 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;  // simdgroups are 32 consecutive threads
        if (lane == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + simd_id * 64 + m * 8;
                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
            }
        }
    }
}
1606
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 grid of 8×8 output
// sub-tiles, with each simdgroup owning *two* stacked 8×8 accumulators
// along the row axis:
//
//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//     output sub-tiles (tok, row):
//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration (2 MMAs per 3 fragment loads vs 1 per 2 in the 16×16
// single-accumulator kernel) and needs a quarter as many threadgroups as
// the 16×16 tile for the same output.
//
// Threadgroup memory: w_tile + t_tile = 8 KB.  The partial-token store path
// reuses w_tile as its staging buffer, so it needs no extra allocation.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
// The cooperative loads below also require exactly 256 threads per
// threadgroup (4 elements per thread × 256 = 1024 tile elements).
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and no
    // thread is left waiting at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // bytes per weight row (34-byte blocks)

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);            // block scale
                int ival = int(*(device const int8_t*)(rp + 2 + k));  // k-th quant
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load; rows past num_tokens are zero-filled.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Tiles must be fully written before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration's tile writes from racing with this
        // iteration's simdgroup_load reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage both accumulators and scalar-copy only the
        // valid token rows.  w_tile doubles as the staging buffer: the final
        // threadgroup_barrier of the K loop guarantees every simdgroup has
        // finished reading it, and each simdgroup writes its own disjoint
        // 128-float slice (8 simdgroups × 2 accs × 64 = 1024 floats = w_tile).
        threadgroup float* scratch = w_tile;
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;  // simdgroups are 32 consecutive threads
        if (lane == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1747
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory.  That halves the tile footprint (2 × 2 KB = 4 KB vs 8 KB for the
// float tiles of mma32), which is intended to raise concurrent-threadgroup
// occupancy per GPU core on Apple Silicon.
// NOTE(review): the float `scratch[8 * 2 * 64]` in the partial-token store
// path is another 4 KB of threadgroup memory.  If it is reserved statically,
// this kernel's total is 8 KB rather than 4 KB.  Confirm via the pipeline's
// staticThreadgroupMemoryLength.
//
// Precision:
//   * Each Q8_0 weight is an int8 quant times the block's `half` scale.  The
//     product is computed in float and rounded to `half` when stored in
//     w_tile, so low-order bits can be lost: |q| up to 128 times an 11-bit
//     half significand can need more significand bits than half provides.
//   * Assumes block scales are small enough that |q * scale| <= 65504.
//   * Token activations are narrowed f32 → f16 on load.  Assumes they stay
//     within half range (TODO confirm for every call site, e.g. FFN
//     down-projection inputs).
//   * The accumulators C_a / C_b are simdgroup_matrix<float>, so the K-sum
//     itself is not carried in half.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
// accumulators each.  The cooperative loads below assume 256 threads per
// threadgroup (4 elements per thread × 256 = 1024 tile elements).  The
// intended win vs mma32 is occupancy at moderate prefill lengths, where the
// GPU is wave-starved.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid: the whole threadgroup exits together, so no
    // thread is stranded at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Same simdgroup layout as matmul_q8_mma32: simd_id / 2 picks an 8-token
    // band, simd_id % 2 picks a 16-row half split into two 8-row sub-tiles.
    uint sg_tok_idx  = simd_id / 2;
    uint sg_row_half = simd_id % 2;
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    // Float accumulators even though the tiles are half.
    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16 (half scale at rp[0..2],
        // then 32 int8 quants).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); token rows past
        // num_tokens are zero-filled.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be fully written before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8; A is shared between both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Prevents the next iteration's tile writes from racing with this
        // iteration's loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both accumulators.  Rows are not bounds-checked here, so this
    // relies on rows being a multiple of MMA32_ROW_TILE; only the token
    // dimension may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage in this simdgroup's private 128-float
        // slice of scratch, then lane 0 copies only the valid token rows.
        threadgroup float scratch[8 * 2 * 64];
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1876
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup,
// relative to its own 16×16 quadrant:
//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
// simd_id addresses one 16×16 quadrant of the 32×32 output tile
// (simd_id / 2 selects the token half, simd_id % 2 selects the row half).
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions, reusing A_top with both B's and A_bot with
// both B's.  That is 4 MMAs per 4 fragment loads, vs 2 per 3 in the
// 2-accumulator kernel (about 1.5× the MMA-per-load ratio).
//
// It also halves the thread count per threadgroup (128 threads).  That may
// improve occupancy on Apple GPUs, where the concurrent-thread budget
// rather than shared-memory size is expected to be the tighter limit.
//
// Threadgroup memory: half tiles 4 KB, plus the float scratch[4 * 4 * 64]
// (4 KB) used by the partial-token store path.
//
// Like matmul_q8_mma32, weight rows and output rows are not bounds-checked,
// so this relies on rows % 32 == 0 and cols % 32 == 0.  The cooperative
// loads assume exactly 128 threads per threadgroup (8 elements per thread
// × 128 = 1024 tile elements).
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid: the whole threadgroup exits together, so no
    // thread is stranded at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // Float accumulators even though the tiles are half.
    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        // Each 34-byte block: half scale at rp[0..2], then 32 int8 quants.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16); token rows past
        // num_tokens are zero-filled.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be fully written before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Prevents the next iteration's tile writes from racing with this
        // iteration's loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  Full-tile fast path assumes full 16×16 valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch (each simdgroup
        // owns a private 256-float slice).  A simdgroup whose quadrant lies
        // entirely past num_tokens also lands here; its per-row checks
        // below then copy nothing.
        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base +   0, 8);
        simdgroup_store(C_01, scratch + sg_base +  64, 8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch + sg_base +   0 + m * 8;
                    threadgroup const float* src1 = scratch + sg_base +  64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2031
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)).  Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// Half accumulators also saturate to inf above 65504, so inputs that are
// not normalized (e.g. the FFN down-projection input) are the riskiest case.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Tiling, as implied by the indexing below: each threadgroup computes an
// [MMA32_TOK_TILE tokens x MMA32_ROW_TILE rows] output tile with 4
// simdgroups (2 token halves x 2 row halves), each owning a 16x16 sub-tile
// as four 8x8 accumulators.  K is consumed 32 columns (one Q8_0 block of
// 34 bytes: half scale + 32 int8) per outer iteration.
// NOTE(review): weight rows are not bounds-checked, so this assumes rows is
// a multiple of MMA32_ROW_TILE, cols a multiple of 32, and a 128-thread
// threadgroup (tid * 8 fills a 32x32 tile) — TODO confirm against dispatch.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup, so returning before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Dequantize one Q8_0 block per weight row into w_tile[r][k].
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        // Stage the matching 32 input columns; tokens past num_tokens are zero.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            // Weights are stored [row][k]; load transposed so B is [k][row].
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;

    // Per-simdgroup scratch: element e of this simdgroup's 256-entry slice is
    // accumulator e / 64 (C_00, C_01, C_10, C_11), token (e / 8) % 8, row e % 8.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    // Each simdgroup only reads back its own slice, so a simdgroup-scope
    // barrier is enough to make the stores visible to all of its lanes.
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // All 32 lanes share the 256 widening stores (8 each) instead of lane 0
    // issuing them serially; each pass writes 4 tokens x 8 contiguous rows.
    uint lane = tid % 32;
    for (uint e = lane; e < 256; e += 32) {
        uint quad = e / 64;        // 0: C_00, 1: C_01, 2: C_10, 3: C_11
        uint m    = (e / 8) % 8;   // token offset within the 8x8 tile
        uint j    = e % 8;         // row offset within the 8x8 tile
        uint tok  = ((quad < 2) ? out_tok_top : out_tok_bot) + m;
        uint row  = (((quad & 1u) == 0) ? out_row_lo : out_row_hi) + j;
        if (tok < num_tokens) {
            outputs[tok * rows + row] = float(scratch[sg_base + e]);
        }
    }
}
2169
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast of a per-row bias over a row-major [num_tokens, rows]
// matrix: one thread per element, threads past num_tokens * rows do nothing.
// Used for Qwen2 QKV bias after the fused qkv matmul.
//     out[token, i] += bias[i]    for i in 0..rows, token in 0..num_tokens
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows]
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * rows) {
        // The column index within the row selects the bias entry.
        out[id] = out[id] + bias[id % rows];
    }
}
2186
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors:
//     outputs[token, r] = dot(dequant(matrix[r, :]), inputs[token, :])
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups; consecutive groups of
// ceil(rows/Q4_ROWS_PER_TG) threadgroups share one token.
//
// Q4_0 layout as decoded below: each row is cols/32 blocks of 18 bytes, a
// half scale followed by 16 packed bytes.  The low nibble of byte j is
// element j and the high nibble is element j + 16, both offset by -8.
//
// Each simdgroup produces 4 consecutive output rows; its lanes stride over
// the row's blocks and the per-lane partial sums are combined by simd_sum.
// NOTE(review): assumes cols % 32 == 0, cols <= VEC_TILE_SIZE, a 256-thread
// threadgroup (vec_tile fill stride) and Q4_ROWS_PER_SG == 4 — TODO confirm
// against the host-side dispatch.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory once; every
    // simdgroup reuses it for its rows.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // When rows is not a multiple of 4 the trailing rows of the last
    // simdgroup do not exist.  Clamp them to the last real row so every load
    // below stays inside the matrix buffer; their sums are discarded by the
    // guarded stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;
        uint vb = blk * 32;

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 hold elements 0..15 (low nibbles), v4..v7 elements 16..31
        // (high nibbles) of this block's slice of the input vector.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        // Apply the per-block scale once, after the integer-valued dot.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2287
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter one K or V slice per token from the strided f32 batch QKV buffer
// into the f16 KV cache, starting at cache row base_pos.  One thread per
// copied element; threads past M * kv_dim do nothing.
// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
// dst layout: contiguous [max_seq, kv_dim] cache.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer (f32)
    device half* dst         [[buffer(1)]],  // KV cache (f16)
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // elements per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    uint n_elems = M * kv_dim;
    if (id < n_elems) {
        uint tok = id / kv_dim;
        uint col = id - tok * kv_dim;  // same as id % kv_dim
        dst[(base_pos + tok) * kv_dim + col] = half(src[tok * src_stride + src_offset + col]);
    }
}
2310
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// Layout implied by the indexing below: a 256-thread threadgroup made of 8
// simdgroups of 32 (the strides of 8 and 256 and the 8-entry reduction
// arrays); K/V cache rows are num_kv_heads * head_dim halves; query head
// `head` reads KV head head / (num_heads / num_kv_heads) (grouped-query).
// NOTE(review): the whole score row for the token lives in the threadgroup
// array scores[ATTN_SCORES_SIZE] and seq_len is not checked against it, so
// callers must keep base_pos + M <= ATTN_SCORES_SIZE — TODO confirm the
// host-side guard.
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const half*  k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half*  v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // Uniform per threadgroup, so returning before the barriers is safe.
    if (token_idx >= M) return;

    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores — sized to the effective max_seq_len
    // (4096 for all supported models) so long-context attention doesn't overflow.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction
    // Each simdgroup takes every 8th key position; its 32 lanes split head_dim
    // and simd_sum combines the partial dot products.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // Max: per-thread strided max -> simd_max -> one value per simdgroup in
    // shared_max -> thread 0 folds the 8 values and broadcasts via shared_max[0].
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exponentiate in place and reduce the sum the same way; thread 0 stores
    // the reciprocal (0 if the sum is not positive) in shared_sum[0].
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // Each active thread owns 4 consecutive output dims and walks the whole
    // sequence, so only head_dim/4 threads do work here (16 of 256 for
    // head_dim = 64); the rest idle through this step.  The half4 loads assume
    // kv_head * head_dim + d is a multiple of 4 (i.e. head_dim % 4 == 0); the
    // scalar loop afterwards only covers tail dims, not misaligned bases.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        uint seq_len4 = seq_len & ~3u;
        // Sequence unrolled by 4; the loop below handles the remainder.
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
                 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
                 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
                 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2436
// ── attention_flash_batch ─────────────────────────────────────────────
// Streaming attention with online softmax.  Same grid as attention_batch
// (M × num_heads threadgroups, one per (token, head) pair) but the scores
// matrix is never materialized.  K/V positions are processed in a tile of
// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
// the standard flash-attention recurrence:
//
//     m_new   = max(m_old, tile_max)
//     alpha   = exp(m_old - m_new)
//     l_new   = alpha * l_old + sum(exp(S - m_new))
//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
//     O_final = O / l
//
// Unlike attention_batch, which keeps a score for every attended position
// in a fixed threadgroup array (ATTN_SCORES_SIZE entries, not checked
// against seq_len), this kernel only holds one K tile of scores at a time,
// so its threadgroup memory is fixed at O(FLASH_MAX_HEAD_DIM + FLASH_K_TILE)
// floats regardless of seq_len.
//
// Assumptions (not checked in-kernel): head_dim ≤ FLASH_MAX_HEAD_DIM = 256
// (Llama/Qwen/Mistral/Phi-3 all satisfy this), and a 256-thread threadgroup
// of 8 simdgroups (the strides of 256 and 8 and the 8-entry sg_scratch).
constant constexpr uint FLASH_K_TILE = 32;
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;

kernel void attention_flash_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const half*  k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half*  v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],
    constant uint& base_pos          [[buffer(5)]],
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // Uniform per threadgroup, so returning before the barriers is safe.
    if (token_idx >= M) return;

    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: attend to [0, base_pos + token_idx]
    uint q_off = token_idx * q_stride + head * head_dim;
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup state:
    //   q_sh:      Q vector for this (token, head), loaded once
    //   o_sh:      running output vector, updated each K tile
    //   scores_sh: scores for the current K tile only
    //   stats:     [running max, running sum]  (see flash-attention recurrence)
    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float scores_sh[FLASH_K_TILE];
    threadgroup float stats[2];
    threadgroup float sg_scratch[8];  // simdgroup-level reduction buffer

    // --- Load Q (one row) and zero the running O ---
    for (uint d = tid; d < head_dim; d += 256) {
        q_sh[d] = q_batch[q_off + d];
        o_sh[d] = 0.0f;
    }
    if (tid == 0) {
        stats[0] = -INFINITY;
        stats[1] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = fast::rsqrt(float(head_dim));
    uint v_stride = num_kv_heads * head_dim;
    uint v_base = kv_head * head_dim;

    // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
    // Every barrier below separates a write phase from the phase that reads
    // it (sg_scratch is reused for max, broadcast and sum), so statement order
    // within an iteration is load-bearing.
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);

        // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
        // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
        for (uint ti = simd_id; ti < tile_n; ti += 8) {
            uint k_off = (kv_base + ti) * v_stride + v_base;  // same layout as V stride
            float dot = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q_sh[d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores_sh[ti] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Tile max via cooperative reduction.
        float local_max = -INFINITY;
        for (uint s = tid; s < tile_n; s += 256) {
            local_max = max(local_max, scores_sh[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // [3] Merge with running max, compute alpha, rescale running l.
        // Only thread 0 computes these; the other threads pick them up from
        // sg_scratch[0..1] after the barrier.
        float m_new;
        float alpha;
        if (tid == 0) {
            float tile_max = sg_scratch[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
            float m_old = stats[0];
            m_new = max(m_old, tile_max);
            // First iteration: m_old = -inf → alpha = 0 (reset O).
            // NOTE(review): relies on -INFINITY being honoured; confirm the
            // shader library is not compiled with fast-math assumptions that
            // could fold this comparison away.
            alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            stats[0] = m_new;
            stats[1] *= alpha;
            // Broadcast via sg_scratch.
            sg_scratch[0] = alpha;
            sg_scratch[1] = m_new;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        alpha = sg_scratch[0];
        m_new = sg_scratch[1];

        // [4] Rescale running output by alpha, then compute exp(scores - m_new).
        for (uint d = tid; d < head_dim; d += 256) {
            o_sh[d] *= alpha;
        }
        for (uint s = tid; s < tile_n; s += 256) {
            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Tile sum → update running l.
        float local_sum = 0.0f;
        for (uint s = tid; s < tile_n; s += 256) {
            local_sum += scores_sh[s];
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
            stats[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = 0.0f;
            for (uint s = 0; s < tile_n; s++) {
                acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
            }
            o_sh[d] += acc;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // --- Normalize and write output ---
    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        output_batch[out_off + d] = o_sh[d] * inv_l;
    }
}
2601
2602// ── attention_mma_flash_batch ─────────────────────────────────────────
2603// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
2604// both Q·K^T and P·V.  Processes Q_BLOCK=8 tokens of one head per
2605// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
2606// K/V loads across 8 Q rows and using hardware matrix-multiply for the
2607// arithmetic.
2608//
2609// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
2610// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
2611// to attention_batch / attention_flash_batch otherwise.
2612//
2613// Online softmax recurrence is identical to attention_flash_batch but
2614// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
constant constexpr uint FLASH_MMA_Q_BLOCK = 8;         // query tokens of one head per threadgroup
constant constexpr uint FLASH_MMA_K_BLOCK = 32;        // K/V positions staged per kv_base iteration
constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;  // largest head_dim the fixed-size tiles can hold
2618
2619kernel void attention_mma_flash_batch(
2620    device const float* q_batch      [[buffer(0)]],
2621    device const half*  k_cache      [[buffer(1)]],
2622    device const half*  v_cache      [[buffer(2)]],
2623    device float* output_batch       [[buffer(3)]],
2624    constant uint& M                 [[buffer(4)]],
2625    constant uint& base_pos          [[buffer(5)]],
2626    constant uint& num_heads         [[buffer(6)]],
2627    constant uint& num_kv_heads      [[buffer(7)]],
2628    constant uint& head_dim          [[buffer(8)]],
2629    constant uint& q_stride          [[buffer(9)]],
2630    uint2 tgid [[threadgroup_position_in_grid]],
2631    uint tid [[thread_index_in_threadgroup]],
2632    uint simd_lane [[thread_index_in_simdgroup]],
2633    uint simd_id [[simdgroup_index_in_threadgroup]])
2634{
2635    uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2636    uint head = tgid.y;
2637    if (q_block_start >= M) return;
2638    uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2639
2640    uint kv_head = head / (num_heads / num_kv_heads);
2641    // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2642    // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2643    uint seq_len = base_pos + q_block_start + q_valid;
2644    float scale = fast::rsqrt(float(head_dim));
2645
2646    uint kv_stride = num_kv_heads * head_dim;
2647    uint kv_base_off = kv_head * head_dim;
2648
2649    // ── Threadgroup memory ──
2650    // q_sh:  [Q_BLOCK, head_dim] half  — Q tile, loaded once
2651    // k_sh:  [K_BLOCK, head_dim] half  — K tile, refreshed per kv_base iter
2652    // v_sh:  [K_BLOCK, head_dim] half  — V tile, refreshed per kv_base iter
2653    // s_sh:  [Q_BLOCK, K_BLOCK] float  — raw Q·K^T scores, then scaled+masked
2654    // p_sh:  [Q_BLOCK, K_BLOCK] half   — softmax probabilities (for P·V MMA)
2655    // o_sh:  [Q_BLOCK, head_dim] float — running output accumulator
2656    // m_sh:  [Q_BLOCK] float           — running max per Q row
2657    // l_sh:  [Q_BLOCK] float           — running softmax denominator per Q row
2658    // scratch: 4*Q_BLOCK floats        — per-row reduction scratch
2659    threadgroup half  q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2660    threadgroup half  k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2661    threadgroup half  v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2662    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2663    threadgroup half  p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2664    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2665    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2666    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2667    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2668
2669    // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2670    uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2671    for (uint i = tid; i < qblock_elems; i += 128) {
2672        uint q = i / head_dim;
2673        uint d = i % head_dim;
2674        if (q < q_valid) {
2675            uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2676            q_sh[q * head_dim + d] = half(q_batch[q_off]);
2677        } else {
2678            q_sh[q * head_dim + d] = half(0);
2679        }
2680        o_sh[q * head_dim + d] = 0.0f;
2681    }
2682    if (tid < FLASH_MMA_Q_BLOCK) {
2683        m_sh[tid] = -INFINITY;
2684        l_sh[tid] = 0.0f;
2685    }
2686    threadgroup_barrier(mem_flags::mem_threadgroup);
2687
2688    // ── Stream K/V in K_BLOCK chunks ──
2689    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2690        uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2691
2692        // Load K and V tile into TG memory (as half).
2693        uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2694        for (uint i = tid; i < kv_tile_elems; i += 128) {
2695            uint k_pos = i / head_dim;
2696            uint d = i % head_dim;
2697            if (k_pos < tile_n) {
2698                uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2699                k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2700                v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2701            } else {
2702                k_sh[k_pos * head_dim + d] = half(0);
2703                v_sh[k_pos * head_dim + d] = half(0);
2704            }
2705        }
2706        threadgroup_barrier(mem_flags::mem_threadgroup);
2707
2708        // ── Phase 1: S = Q @ K^T via MMA ──
2709        // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2710        // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2711        // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2712        {
2713            simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2714            uint dim_chunks = head_dim / 8;
2715            for (uint dc = 0; dc < dim_chunks; dc++) {
2716                simdgroup_matrix<half, 8, 8> A, B;
2717                // A = Q[0:8, dc*8 : dc*8+8]  (rows of Q, no transpose)
2718                simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2719                // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2720                // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2721                // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2722                // transpose=true, which places it in the register as K^T of that sub-block.
2723                simdgroup_load(B,
2724                    k_sh + (simd_id * 8) * head_dim + dc * 8,
2725                    head_dim, ulong2(0, 0), true);
2726                simdgroup_multiply_accumulate(C, A, B, C);
2727            }
2728            // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2729            simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2730        }
2731        threadgroup_barrier(mem_flags::mem_threadgroup);
2732
2733        // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2734        // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2735        uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2736        for (uint i = tid; i < s_elems; i += 128) {
2737            uint q = i / FLASH_MMA_K_BLOCK;
2738            uint k = i % FLASH_MMA_K_BLOCK;
2739            uint global_q = q_block_start + q;
2740            uint global_kv = kv_base + k;
2741            bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2742            s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2743        }
2744        threadgroup_barrier(mem_flags::mem_threadgroup);
2745
2746        // ── Phase 2b: per-row max via simdgroup reduction ──
2747        // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2748        // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2749        {
2750            uint row_base = simd_id * 2;
2751            for (uint qr = 0; qr < 2; qr++) {
2752                uint q = row_base + qr;
2753                float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2754                float row_max = simd_max(my);
2755                if (simd_lane == 0) {
2756                    scratch[q] = row_max;  // tile_max[q]
2757                }
2758            }
2759        }
2760        threadgroup_barrier(mem_flags::mem_threadgroup);
2761
2762        // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2763        if (tid < FLASH_MMA_Q_BLOCK) {
2764            uint q = tid;
2765            float m_old = m_sh[q];
2766            float tile_max = scratch[q];
2767            float m_new = max(m_old, tile_max);
2768            float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2769            m_sh[q] = m_new;
2770            l_sh[q] = l_sh[q] * alpha;
2771            // scratch[q]               = m_new   (for phase 2d)
2772            // scratch[Q_BLOCK + q]     = alpha   (for phase 3)
2773            scratch[q] = m_new;
2774            scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2775        }
2776        threadgroup_barrier(mem_flags::mem_threadgroup);
2777
2778        // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2779        {
2780            uint row_base = simd_id * 2;
2781            for (uint qr = 0; qr < 2; qr++) {
2782                uint q = row_base + qr;
2783                float m_new = scratch[q];
2784                float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2785                p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2786                float row_sum = simd_sum(p);
2787                if (simd_lane == 0) {
2788                    scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2789                }
2790            }
2791        }
2792        threadgroup_barrier(mem_flags::mem_threadgroup);
2793
2794        // ── Phase 2e: l_sh += tile_sum ──
2795        if (tid < FLASH_MMA_Q_BLOCK) {
2796            uint q = tid;
2797            l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2798        }
2799
2800        // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2801        threadgroup_barrier(mem_flags::mem_threadgroup);
2802        for (uint i = tid; i < qblock_elems; i += 128) {
2803            uint q = i / head_dim;
2804            float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2805            o_sh[i] *= alpha;
2806        }
2807        threadgroup_barrier(mem_flags::mem_threadgroup);
2808
2809        // ── Phase 4: O += P @ V via MMA ──
2810        // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2811        // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2812        // For head_dim=64: 16 dims/sg = 2 tiles.  head_dim=128: 32 dims/sg = 4 tiles.
2813        {
2814            uint dims_per_sg = head_dim / 4;       // 16 or 32
2815            uint tiles_per_sg = dims_per_sg / 8;   // 2 or 4
2816            uint sg_d_base = simd_id * dims_per_sg;
2817            for (uint t = 0; t < tiles_per_sg; t++) {
2818                uint d_base = sg_d_base + t * 8;
2819                simdgroup_matrix<float, 8, 8> O_acc;
2820                simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2821                uint k_chunks = FLASH_MMA_K_BLOCK / 8;  // 4
2822                for (uint kc = 0; kc < k_chunks; kc++) {
2823                    simdgroup_matrix<half, 8, 8> A, B;
2824                    simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2825                                   ulong2(0, 0), false);
2826                    simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2827                                   ulong2(0, 0), false);
2828                    simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2829                }
2830                simdgroup_store(O_acc, o_sh + d_base, head_dim);
2831            }
2832        }
2833        threadgroup_barrier(mem_flags::mem_threadgroup);
2834    }
2835
2836    // ── Finalize: O /= l, write to output ──
2837    for (uint i = tid; i < qblock_elems; i += 128) {
2838        uint q = i / head_dim;
2839        uint d = i % head_dim;
2840        if (q < q_valid) {
2841            float l = l_sh[q];
2842            float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2843            uint token_idx = q_block_start + q;
2844            uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2845            output_batch[out_off] = o_sh[i] * inv_l;
2846        }
2847    }
2848}
2849
2850// ── rope_qk_batch ─────────────────────────────────────────────────────
2851// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2852// launch + memory barrier per layer. Both Q and K live in the same
2853// qkv_data buffer at different offsets within each token's row.
2854// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2855kernel void rope_qk_batch(
2856    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
2857    constant uint& M                 [[buffer(1)]],   // num tokens
2858    constant uint& base_pos          [[buffer(2)]],   // starting position
2859    constant uint& num_q_heads       [[buffer(3)]],
2860    constant uint& num_kv_heads      [[buffer(4)]],
2861    constant uint& head_dim          [[buffer(5)]],
2862    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
2863    constant float& theta            [[buffer(7)]],
2864    uint id [[thread_position_in_grid]])
2865{
2866    uint half_dim = head_dim / 2;
2867    uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2868    uint token = id / total_pairs;
2869    uint pair = id % total_pairs;
2870    if (token >= M) return;
2871
2872    uint pos = base_pos + token;
2873    uint q_pairs = num_q_heads * half_dim;
2874
2875    uint h, i, offset;
2876    if (pair < q_pairs) {
2877        // Q head
2878        h = pair / half_dim;
2879        i = pair % half_dim;
2880        offset = token * qkv_stride + h * head_dim + i * 2;
2881    } else {
2882        // K head
2883        uint kp = pair - q_pairs;
2884        h = kp / half_dim;
2885        i = kp % half_dim;
2886        uint k_start = num_q_heads * head_dim;
2887        offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2888    }
2889
2890    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2891    float angle = float(pos) * freq;
2892    float cos_val = cos(angle);
2893    float sin_val = sin(angle);
2894
2895    float x0 = qkv_data[offset];
2896    float x1 = qkv_data[offset + 1];
2897    qkv_data[offset]     = x0 * cos_val - x1 * sin_val;
2898    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2899}
2900
2901// ── copy_kv_both_batch ────────────────────────────────────────────────
2902// Fused K+V cache copy in a single dispatch: copies both K and V from
2903// the strided batch QKV buffer to their respective KV cache buffers.
2904// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2905kernel void copy_kv_both_batch(
2906    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride] f32
2907    device half* k_dst         [[buffer(1)]],  // K cache [max_seq, kv_dim] f16
2908    device half* v_dst         [[buffer(2)]],  // V cache [max_seq, kv_dim] f16
2909    constant uint& M           [[buffer(3)]],  // num tokens in batch
2910    constant uint& kv_dim      [[buffer(4)]],  // elements per KV vector
2911    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
2912    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
2913    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
2914    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
2915    uint id [[thread_position_in_grid]])
2916{
2917    // Total elements = M * kv_dim * 2 (K + V)
2918    uint total_kv = M * kv_dim;
2919    if (id >= total_kv * 2) return;
2920
2921    uint is_v = id / total_kv;        // 0 = K, 1 = V
2922    uint local_id = id % total_kv;
2923    uint token = local_id / kv_dim;
2924    uint d = local_id % kv_dim;
2925
2926    uint dst_off = (base_pos + token) * kv_dim + d;
2927    uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2928
2929    if (is_v) {
2930        v_dst[dst_off] = half(src[src_off]);
2931    } else {
2932        k_dst[dst_off] = half(src[src_off]);
2933    }
2934}
2935"#
2936    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2937    .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2938}
2939
2940// ---------------------------------------------------------------------------
2941// model.rs generation
2942// ---------------------------------------------------------------------------
2943
2944fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2945    let mut code = String::with_capacity(48 * 1024);
2946    emit_model_header(&mut code, config)?;
2947    emit_metal_model_struct(&mut code, config)?;
2948    emit_layer_buffers_struct(&mut code, config)?;
2949    emit_metal_model_impl(&mut code, config)?;
2950    emit_helper_functions(&mut code)?;
2951    Ok(code)
2952}
2953
2954fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2955    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2956    writeln!(
2957        code,
2958        "//! Model: {} ({} layers, hidden={})",
2959        config.architecture, config.num_layers, config.hidden_size
2960    )?;
2961    writeln!(code, "//!")?;
2962    writeln!(
2963        code,
2964        "//! Uses native Metal compute pipelines via the metal crate."
2965    )?;
2966    writeln!(code)?;
2967    writeln!(code, "#![allow(dead_code)]")?;
2968    writeln!(code)?;
2969    writeln!(code, "use metal::*;")?;
2970    writeln!(code, "#[allow(unused_imports)]")?;
2971    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2972    writeln!(code, "use std::mem;")?;
2973    writeln!(code)?;
2974
2975    // Model constants
2976    writeln!(
2977        code,
2978        "// ── Model constants ──────────────────────────────────"
2979    )?;
2980    writeln!(
2981        code,
2982        "pub const HIDDEN_SIZE: usize = {};",
2983        config.hidden_size
2984    )?;
2985    writeln!(
2986        code,
2987        "pub const INTERMEDIATE_SIZE: usize = {};",
2988        config.intermediate_size
2989    )?;
2990    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2991    writeln!(
2992        code,
2993        "pub const NUM_HEADS: usize = {};",
2994        config.num_attention_heads
2995    )?;
2996    writeln!(
2997        code,
2998        "pub const NUM_KV_HEADS: usize = {};",
2999        config.num_kv_heads
3000    )?;
3001    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3002    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3003    let effective_seq_len = config.max_seq_len.min(4096);
3004    writeln!(
3005        code,
3006        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
3007        effective_seq_len, config.max_seq_len
3008    )?;
3009    writeln!(
3010        code,
3011        "pub const RMS_NORM_EPS: f32 = {:e};",
3012        config.rms_norm_eps
3013    )?;
3014    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3015    writeln!(
3016        code,
3017        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3018    )?;
3019    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3020    writeln!(code)?;
3021
3022    Ok(())
3023}
3024
3025fn emit_metal_model_struct(
3026    code: &mut String,
3027    config: &ModelConfig,
3028) -> Result<(), MetalCodegenError> {
3029    writeln!(
3030        code,
3031        "// ── MetalModel ──────────────────────────────────────────"
3032    )?;
3033    writeln!(code)?;
3034    writeln!(
3035        code,
3036        "/// Metal-accelerated transformer model for Apple Silicon."
3037    )?;
3038    writeln!(code, "///")?;
3039    writeln!(
3040        code,
3041        "/// Uses unified memory for zero-copy weight access and native Metal"
3042    )?;
3043    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3044    writeln!(code, "pub struct MetalModel {{")?;
3045    writeln!(code, "    device: Device,")?;
3046    writeln!(code, "    queue: CommandQueue,")?;
3047    writeln!(code)?;
3048    writeln!(code, "    // ── Compute pipelines ──")?;
3049    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
3050    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
3051    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
3052    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
3053    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
3054    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
3055    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
3056    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
3057    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
3058    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
3059    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
3060    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
3061    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
3062    writeln!(
3063        code,
3064        "    copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3065    )?;
3066    writeln!(code)?;
3067    writeln!(code, "    // ── Batched prefill pipelines ──")?;
3068    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
3069    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
3070    writeln!(
3071        code,
3072        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3073    )?;
3074    writeln!(code, "    matmul_q8_mma_pipeline: ComputePipelineState,")?;
3075    writeln!(code, "    matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3076    writeln!(
3077        code,
3078        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3079    )?;
3080    writeln!(
3081        code,
3082        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3083    )?;
3084    writeln!(
3085        code,
3086        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3087    )?;
3088    if config.qkv_bias {
3089        writeln!(code, "    add_bias_batch_pipeline: ComputePipelineState,")?;
3090    }
3091    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
3092    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
3093    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
3094    writeln!(
3095        code,
3096        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
3097    )?;
3098    writeln!(
3099        code,
3100        "    add_inplace_batch_pipeline: ComputePipelineState,"
3101    )?;
3102    writeln!(
3103        code,
3104        "    copy_embedding_batch_pipeline: ComputePipelineState,"
3105    )?;
3106    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
3107    writeln!(
3108        code,
3109        "    attention_flash_batch_pipeline: ComputePipelineState,"
3110    )?;
3111    writeln!(
3112        code,
3113        "    attention_mma_flash_batch_pipeline: ComputePipelineState,"
3114    )?;
3115    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
3116    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
3117    writeln!(
3118        code,
3119        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
3120    )?;
3121    writeln!(code)?;
3122    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
3123    writeln!(
3124        code,
3125        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3126    )?;
3127    writeln!(code, "    embed_buf: Buffer,")?;
3128    writeln!(code)?;
3129    writeln!(code, "    /// Per-layer weight buffers")?;
3130    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
3131    writeln!(code)?;
3132    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
3133    writeln!(code, "    norm_buf: Buffer,")?;
3134    writeln!(code)?;
3135    writeln!(
3136        code,
3137        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3138    )?;
3139    writeln!(code, "    lm_head_buf: Buffer,")?;
3140    writeln!(code)?;
3141    writeln!(
3142        code,
3143        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
3144    )?;
3145    writeln!(code, "    hidden_buf: Buffer,")?;
3146    writeln!(code, "    residual_buf: Buffer,")?;
3147    writeln!(code, "    normed_buf: Buffer,")?;
3148    writeln!(
3149        code,
3150        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3151    )?;
3152    writeln!(code, "    qkv_buf: Buffer,")?;
3153    writeln!(code, "    attn_out_buf: Buffer,")?;
3154    writeln!(code, "    attn_proj_buf: Buffer,")?;
3155    writeln!(
3156        code,
3157        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3158    )?;
3159    writeln!(code, "    gate_up_buf: Buffer,")?;
3160    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
3161    writeln!(code, "    ffn_out_buf: Buffer,")?;
3162    writeln!(code, "    add_tmp_buf: Buffer,")?;
3163    writeln!(code, "    logits_buf: Buffer,")?;
3164    writeln!(code)?;
3165    writeln!(code, "    // ── Batched prefill working buffers ──")?;
3166    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3167    writeln!(code, "    batch_hidden_buf: Buffer,")?;
3168    writeln!(
3169        code,
3170        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3171    )?;
3172    writeln!(code, "    batch_residual_buf: Buffer,")?;
3173    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3174    writeln!(code, "    batch_qkv_buf: Buffer,")?;
3175    writeln!(
3176        code,
3177        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3178    )?;
3179    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
3180    writeln!(
3181        code,
3182        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3183    )?;
3184    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
3185    writeln!(
3186        code,
3187        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3188    )?;
3189    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
3190    writeln!(
3191        code,
3192        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3193    )?;
3194    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
3195    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3196    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
3197    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
3198    writeln!(code, "    batch_tokens_buf: Buffer,")?;
3199    writeln!(code, "    /// Positions buffer for batch RoPE")?;
3200    writeln!(code, "    batch_positions_buf: Buffer,")?;
3201    writeln!(code)?;
3202    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
3203    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
3204    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
3205    writeln!(code)?;
3206    writeln!(code, "    // ── Inference state ──")?;
3207    writeln!(code, "    pos: usize,")?;
3208    writeln!(code)?;
3209    writeln!(
3210        code,
3211        "    /// Previous command buffer for double-buffered prefill."
3212    )?;
3213    writeln!(
3214        code,
3215        "    /// While the GPU executes token N, the CPU can encode token N+1."
3216    )?;
3217    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
3218    writeln!(code, "}}")?;
3219    writeln!(code)?;
3220
3221    Ok(())
3222}
3223
3224fn emit_layer_buffers_struct(
3225    code: &mut String,
3226    config: &ModelConfig,
3227) -> Result<(), MetalCodegenError> {
3228    writeln!(
3229        code,
3230        "/// Per-layer weight buffers for attention and FFN projections."
3231    )?;
3232    writeln!(code, "struct LayerBuffers {{")?;
3233    writeln!(code, "    attn_norm: Buffer,")?;
3234    writeln!(
3235        code,
3236        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3237    )?;
3238    writeln!(code, "    qkv_weight: Buffer,")?;
3239    if config.qkv_bias {
3240        writeln!(
3241            code,
3242            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3243        )?;
3244        writeln!(code, "    qkv_bias: Buffer,")?;
3245    }
3246    writeln!(code, "    o_weight: Buffer,")?;
3247    writeln!(code, "    ffn_norm: Buffer,")?;
3248    writeln!(
3249        code,
3250        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3251    )?;
3252    writeln!(code, "    gate_up_weight: Buffer,")?;
3253    writeln!(code, "    down_weight: Buffer,")?;
3254    writeln!(code, "}}")?;
3255    writeln!(code)?;
3256
3257    Ok(())
3258}
3259
3260fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3261    let hidden = config.hidden_size;
3262    let intermediate = config.intermediate_size;
3263    let _num_layers = config.num_layers;
3264    let _num_heads = config.num_attention_heads;
3265    let num_kv_heads = config.num_kv_heads;
3266    let head_dim = config.head_dim;
3267    let vocab = config.vocab_size;
3268    let effective_seq_len = config.max_seq_len.min(4096);
3269    let is_q8 = config.dtype == DType::Q8_0;
3270    let is_q4 = config.dtype == DType::Q4_0;
3271    let kv_dim = num_kv_heads * head_dim;
3272
3273    writeln!(code, "impl MetalModel {{")?;
3274
3275    // ── new() ──
3276    writeln!(
3277        code,
3278        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3279    )?;
3280    writeln!(code, "    ///")?;
3281    writeln!(
3282        code,
3283        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3284    )?;
3285    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3286    writeln!(
3287        code,
3288        "        let device = Device::system_default().expect(\"no Metal device found\");"
3289    )?;
3290    writeln!(code, "        let queue = device.new_command_queue();")?;
3291    writeln!(code)?;
3292
3293    // Compile shaders
3294    writeln!(
3295        code,
3296        "        // Compile Metal shaders from embedded source"
3297    )?;
3298    writeln!(
3299        code,
3300        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3301    )?;
3302    writeln!(
3303        code,
3304        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3305    )?;
3306    writeln!(
3307        code,
3308        "            .expect(\"failed to compile Metal shaders\");"
3309    )?;
3310    writeln!(code)?;
3311
3312    // Create compute pipelines
3313    writeln!(code, "        // Create compute pipelines")?;
3314    for (var, fn_name) in [
3315        ("matmul_pipeline", "matmul_vec"),
3316        ("matmul_q8_pipeline", "matmul_vec_q8"),
3317        ("matmul_q4_pipeline", "matmul_vec_q4"),
3318        ("rms_norm_pipeline", "rms_norm"),
3319        ("rope_pipeline", "rope"),
3320        ("softmax_pipeline", "softmax"),
3321        ("silu_mul_pipeline", "silu_mul"),
3322        ("silu_mul_fused_pipeline", "silu_mul_fused"),
3323        ("add_pipeline", "elementwise_add"),
3324        ("attention_pipeline", "attention"),
3325        ("add_inplace_pipeline", "add_inplace"),
3326        ("copy_pipeline", "copy_buffer"),
3327        ("copy_offset_pipeline", "copy_offset"),
3328        ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3329        ("matmul_batch_pipeline", "matmul_vec_batch"),
3330        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3331        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3332        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3333        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3334        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3335        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3336        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3337        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3338        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3339        ("rope_batch_pipeline", "rope_batch"),
3340        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3341        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3342        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3343        ("attention_batch_pipeline", "attention_batch"),
3344        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3345        (
3346            "attention_mma_flash_batch_pipeline",
3347            "attention_mma_flash_batch",
3348        ),
3349        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3350        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3351        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3352    ] {
3353        writeln!(
3354            code,
3355            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3356        )?;
3357    }
3358    if config.qkv_bias {
3359        writeln!(
3360            code,
3361            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3362        )?;
3363    }
3364    writeln!(code)?;
3365
3366    // Weight loading
3367    writeln!(
3368        code,
3369        "        // Load weights into Metal shared-memory buffers"
3370    )?;
3371    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3372    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3373    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3374    writeln!(code)?;
3375    writeln!(
3376        code,
3377        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3378    )?;
3379    writeln!(code)?;
3380    writeln!(
3381        code,
3382        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3383    )?;
3384    writeln!(
3385        code,
3386        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3387    )?;
3388    writeln!(code, "            let byte_len = n * f32_size;")?;
3389    writeln!(code, "            let cur = cursor.get();")?;
3390    writeln!(
3391        code,
3392        "            let data = &weights[cur..cur + byte_len];"
3393    )?;
3394    writeln!(code, "            cursor.set(cur + byte_len);")?;
3395    writeln!(code, "            device.new_buffer_with_data(")?;
3396    writeln!(code, "                data.as_ptr() as *const _,")?;
3397    writeln!(code, "                byte_len as u64,")?;
3398    writeln!(
3399        code,
3400        "                MTLResourceOptions::StorageModeShared,"
3401    )?;
3402    writeln!(code, "            )")?;
3403    writeln!(code, "        }};")?;
3404    writeln!(code)?;
3405
3406    if is_q8 {
3407        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3408        // We load them directly into Metal buffers without dequantizing,
3409        // and use the matmul_vec_q8 shader that operates on quantized data.
3410        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
3411        writeln!(
3412            code,
3413            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3414        )?;
3415        writeln!(
3416            code,
3417            "        // as raw bytes into a Metal buffer (no dequantization)."
3418        )?;
3419        writeln!(
3420            code,
3421            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3422        )?;
3423        writeln!(
3424            code,
3425            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3426        )?;
3427        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3428        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3429        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3430        writeln!(code, "            let cur = cursor.get();")?;
3431        writeln!(
3432            code,
3433            "            let data = &weights[cur..cur + total_raw];"
3434        )?;
3435        writeln!(code, "            cursor.set(cur + total_raw);")?;
3436        writeln!(code, "            device.new_buffer_with_data(")?;
3437        writeln!(code, "                data.as_ptr() as *const _,")?;
3438        writeln!(code, "                total_raw as u64,")?;
3439        writeln!(
3440            code,
3441            "                MTLResourceOptions::StorageModeShared,"
3442        )?;
3443        writeln!(code, "            )")?;
3444        writeln!(code, "        }};")?;
3445        writeln!(code)?;
3446        writeln!(
3447            code,
3448            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3449        )?;
3450        writeln!(
3451            code,
3452            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3453        )?;
3454        writeln!(
3455            code,
3456            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3457        )?;
3458        writeln!(
3459            code,
3460            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3461        )?;
3462        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3463        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3464        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3465        writeln!(code, "            let cur = cursor.get();")?;
3466        writeln!(
3467            code,
3468            "            let data = &weights[cur..cur + total_raw];"
3469        )?;
3470        writeln!(code, "            cursor.set(cur + total_raw);")?;
3471        writeln!(code, "            device.new_buffer_with_data(")?;
3472        writeln!(code, "                data.as_ptr() as *const _,")?;
3473        writeln!(code, "                total_raw as u64,")?;
3474        writeln!(
3475            code,
3476            "                MTLResourceOptions::StorageModeShared,"
3477        )?;
3478        writeln!(code, "            )")?;
3479        writeln!(code, "        }};")?;
3480        writeln!(code)?;
3481    }
3482
3483    if is_q4 {
3484        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3485        // We load them directly into Metal buffers without dequantizing,
3486        // and use the matmul_vec_q4 shader that operates on quantized data.
3487        // This quarters GPU memory usage vs f32 dequantization.
3488        writeln!(
3489            code,
3490            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3491        )?;
3492        writeln!(
3493            code,
3494            "        // as raw bytes into a Metal buffer (no dequantization)."
3495        )?;
3496        writeln!(
3497            code,
3498            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3499        )?;
3500        writeln!(
3501            code,
3502            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3503        )?;
3504        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3505        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3506        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3507        writeln!(code, "            let cur = cursor.get();")?;
3508        writeln!(
3509            code,
3510            "            let data = &weights[cur..cur + total_raw];"
3511        )?;
3512        writeln!(code, "            cursor.set(cur + total_raw);")?;
3513        writeln!(code, "            device.new_buffer_with_data(")?;
3514        writeln!(code, "                data.as_ptr() as *const _,")?;
3515        writeln!(code, "                total_raw as u64,")?;
3516        writeln!(
3517            code,
3518            "                MTLResourceOptions::StorageModeShared,"
3519        )?;
3520        writeln!(code, "            )")?;
3521        writeln!(code, "        }};")?;
3522        writeln!(code)?;
3523        writeln!(
3524            code,
3525            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3526        )?;
3527        writeln!(
3528            code,
3529            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3530        )?;
3531        writeln!(
3532            code,
3533            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3534        )?;
3535        writeln!(
3536            code,
3537            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3538        )?;
3539        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3540        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3541        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3542        writeln!(code, "            let cur = cursor.get();")?;
3543        writeln!(
3544            code,
3545            "            let data = &weights[cur..cur + total_raw];"
3546        )?;
3547        writeln!(code, "            cursor.set(cur + total_raw);")?;
3548        writeln!(code, "            device.new_buffer_with_data(")?;
3549        writeln!(code, "                data.as_ptr() as *const _,")?;
3550        writeln!(code, "                total_raw as u64,")?;
3551        writeln!(
3552            code,
3553            "                MTLResourceOptions::StorageModeShared,"
3554        )?;
3555        writeln!(code, "            )")?;
3556        writeln!(code, "        }};")?;
3557        writeln!(code)?;
3558    }
3559
3560    writeln!(
3561        code,
3562        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3563    )?;
3564    writeln!(code)?;
3565
3566    // Per-layer weights
3567    writeln!(
3568        code,
3569        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3570    )?;
3571    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3572
3573    // attn_norm is always f32
3574    writeln!(
3575        code,
3576        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3577    )?;
3578
3579    let qkv_rows = hidden + 2 * kv_dim;
3580    if is_q8 {
3581        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3582        writeln!(
3583            code,
3584            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3585        )?;
3586        if config.qkv_bias {
3587            writeln!(
3588                code,
3589                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3590            )?;
3591            writeln!(
3592                code,
3593                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3594            )?;
3595        }
3596        writeln!(
3597            code,
3598            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3599        )?;
3600    } else if is_q4 {
3601        writeln!(
3602            code,
3603            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3604        )?;
3605        if config.qkv_bias {
3606            writeln!(
3607                code,
3608                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3609            )?;
3610        }
3611        writeln!(
3612            code,
3613            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3614        )?;
3615    } else {
3616        writeln!(
3617            code,
3618            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3619        )?;
3620        if config.qkv_bias {
3621            writeln!(
3622                code,
3623                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3624            )?;
3625        }
3626        writeln!(
3627            code,
3628            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3629        )?;
3630    }
3631
3632    // ffn_norm is always f32
3633    writeln!(
3634        code,
3635        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3636    )?;
3637
3638    let gate_up_rows = 2 * intermediate;
3639    if is_q8 {
3640        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3641        writeln!(
3642            code,
3643            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3644        )?;
3645        writeln!(
3646            code,
3647            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3648        )?;
3649    } else if is_q4 {
3650        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3651        writeln!(
3652            code,
3653            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3654        )?;
3655        writeln!(
3656            code,
3657            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3658        )?;
3659    } else {
3660        // Fused gate+up weight: read both as a single contiguous f32 buffer
3661        writeln!(
3662            code,
3663            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3664        )?;
3665        writeln!(
3666            code,
3667            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3668        )?;
3669    }
3670
3671    writeln!(code, "            layers.push(LayerBuffers {{")?;
3672    writeln!(code, "                attn_norm,")?;
3673    writeln!(code, "                qkv_weight,")?;
3674    if config.qkv_bias {
3675        writeln!(code, "                qkv_bias,")?;
3676    }
3677    writeln!(code, "                o_weight,")?;
3678    writeln!(code, "                ffn_norm,")?;
3679    writeln!(code, "                gate_up_weight,")?;
3680    writeln!(code, "                down_weight,")?;
3681    writeln!(code, "            }});")?;
3682    writeln!(code, "        }}")?;
3683    writeln!(code)?;
3684
3685    // final_norm is always f32
3686    writeln!(
3687        code,
3688        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3689    )?;
3690    writeln!(code)?;
3691
3692    // lm_head
3693    if is_q8 {
3694        writeln!(
3695            code,
3696            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3697        )?;
3698    } else if is_q4 {
3699        writeln!(
3700            code,
3701            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3702        )?;
3703    } else {
3704        writeln!(
3705            code,
3706            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3707        )?;
3708    }
3709    writeln!(code)?;
3710
3711    // Working buffers
3712    let hidden_bytes = hidden * 4;
3713    let _kv_dim_bytes = kv_dim * 4;
3714    let intermediate_bytes = intermediate * 4;
3715    let vocab_bytes = vocab * 4;
3716    // KV cache is stored as f16 (2 bytes/element) to halve attention memory
3717    // bandwidth in long-context decode.
3718    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3719
3720    writeln!(
3721        code,
3722        "        // Allocate working buffers (shared memory for zero-copy)"
3723    )?;
3724    writeln!(
3725        code,
3726        "        let opts = MTLResourceOptions::StorageModeShared;"
3727    )?;
3728    writeln!(
3729        code,
3730        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3731    )?;
3732    writeln!(
3733        code,
3734        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3735    )?;
3736    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3737    writeln!(
3738        code,
3739        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3740    )?;
3741    writeln!(
3742        code,
3743        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3744    )?;
3745    writeln!(
3746        code,
3747        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3748    )?;
3749    writeln!(
3750        code,
3751        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3752    )?;
3753    writeln!(
3754        code,
3755        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3756    )?;
3757    let gate_up_buf_bytes = 2 * intermediate * 4;
3758    writeln!(
3759        code,
3760        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3761    )?;
3762    writeln!(
3763        code,
3764        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3765    )?;
3766    writeln!(
3767        code,
3768        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3769    )?;
3770    writeln!(
3771        code,
3772        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3773    )?;
3774    writeln!(
3775        code,
3776        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3777    )?;
3778    writeln!(
3779        code,
3780        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3781    )?;
3782    writeln!(code)?;
3783
3784    // Batch prefill working buffers
3785    let batch_hidden_bytes = hidden * 4; // per-token
3786    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3787    let batch_gate_up_bytes = 2 * intermediate * 4;
3788    let batch_intermediate_bytes = intermediate * 4;
3789    writeln!(
3790        code,
3791        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3792    )?;
3793    writeln!(
3794        code,
3795        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3796    )?;
3797    writeln!(
3798        code,
3799        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3800    )?;
3801    writeln!(
3802        code,
3803        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3804    )?;
3805    writeln!(
3806        code,
3807        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3808    )?;
3809    writeln!(
3810        code,
3811        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3812    )?;
3813    writeln!(
3814        code,
3815        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3816    )?;
3817    writeln!(
3818        code,
3819        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3820    )?;
3821    writeln!(
3822        code,
3823        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3824    )?;
3825    writeln!(
3826        code,
3827        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3828    )?;
3829    writeln!(
3830        code,
3831        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3832    )?;
3833    writeln!(code)?;
3834
3835    // KV cache buffers
3836    writeln!(code, "        // KV cache buffers (per-layer)")?;
3837    writeln!(
3838        code,
3839        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3840    )?;
3841    writeln!(
3842        code,
3843        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3844    )?;
3845    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3846    writeln!(
3847        code,
3848        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3849    )?;
3850    writeln!(
3851        code,
3852        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3853    )?;
3854    writeln!(code, "        }}")?;
3855    writeln!(code)?;
3856
3857    writeln!(code, "        Self {{")?;
3858    writeln!(code, "            device,")?;
3859    writeln!(code, "            queue,")?;
3860    writeln!(code, "            matmul_pipeline,")?;
3861    writeln!(code, "            matmul_q8_pipeline,")?;
3862    writeln!(code, "            matmul_q4_pipeline,")?;
3863    writeln!(code, "            rms_norm_pipeline,")?;
3864    writeln!(code, "            rope_pipeline,")?;
3865    writeln!(code, "            softmax_pipeline,")?;
3866    writeln!(code, "            silu_mul_pipeline,")?;
3867    writeln!(code, "            silu_mul_fused_pipeline,")?;
3868    writeln!(code, "            add_pipeline,")?;
3869    writeln!(code, "            attention_pipeline,")?;
3870    writeln!(code, "            add_inplace_pipeline,")?;
3871    writeln!(code, "            copy_pipeline,")?;
3872    writeln!(code, "            copy_offset_pipeline,")?;
3873    writeln!(code, "            copy_f32_to_f16_offset_pipeline,")?;
3874    writeln!(code, "            matmul_batch_pipeline,")?;
3875    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3876    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3877    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3878    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3879    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3880    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3881    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3882    if config.qkv_bias {
3883        writeln!(code, "            add_bias_batch_pipeline,")?;
3884    }
3885    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3886    writeln!(code, "            rms_norm_batch_pipeline,")?;
3887    writeln!(code, "            rope_batch_pipeline,")?;
3888    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3889    writeln!(code, "            add_inplace_batch_pipeline,")?;
3890    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3891    writeln!(code, "            attention_batch_pipeline,")?;
3892    writeln!(code, "            attention_flash_batch_pipeline,")?;
3893    writeln!(code, "            attention_mma_flash_batch_pipeline,")?;
3894    writeln!(code, "            copy_kv_batch_pipeline,")?;
3895    writeln!(code, "            rope_qk_batch_pipeline,")?;
3896    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3897    writeln!(code, "            embed_buf,")?;
3898    writeln!(code, "            layers,")?;
3899    writeln!(code, "            norm_buf,")?;
3900    writeln!(code, "            lm_head_buf,")?;
3901    writeln!(code, "            hidden_buf,")?;
3902    writeln!(code, "            residual_buf,")?;
3903    writeln!(code, "            normed_buf,")?;
3904    writeln!(code, "            qkv_buf,")?;
3905    writeln!(code, "            attn_out_buf,")?;
3906    writeln!(code, "            attn_proj_buf,")?;
3907    writeln!(code, "            gate_up_buf,")?;
3908    writeln!(code, "            ffn_hidden_buf,")?;
3909    writeln!(code, "            ffn_out_buf,")?;
3910    writeln!(code, "            add_tmp_buf,")?;
3911    writeln!(code, "            logits_buf,")?;
3912    writeln!(code, "            batch_hidden_buf,")?;
3913    writeln!(code, "            batch_residual_buf,")?;
3914    writeln!(code, "            batch_qkv_buf,")?;
3915    writeln!(code, "            batch_attn_out_buf,")?;
3916    writeln!(code, "            batch_attn_proj_buf,")?;
3917    writeln!(code, "            batch_gate_up_buf,")?;
3918    writeln!(code, "            batch_ffn_hidden_buf,")?;
3919    writeln!(code, "            batch_ffn_out_buf,")?;
3920    writeln!(code, "            batch_tokens_buf,")?;
3921    writeln!(code, "            batch_positions_buf,")?;
3922    writeln!(code, "            k_cache,")?;
3923    writeln!(code, "            v_cache,")?;
3924    writeln!(code, "            pos: 0,")?;
3925    writeln!(code, "            prev_cmd: None,")?;
3926    writeln!(code, "        }}")?;
3927    writeln!(code, "    }}")?;
3928    writeln!(code)?;
3929
3930    // ── forward() ──
3931    writeln!(
3932        code,
3933        "    /// Run the forward pass for a single token at the current position."
3934    )?;
3935    writeln!(code, "    ///")?;
3936    writeln!(
3937        code,
3938        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3939    )?;
3940    writeln!(code, "    ///")?;
3941    writeln!(
3942        code,
3943        "    /// All GPU operations are encoded into a single command buffer and"
3944    )?;
3945    writeln!(
3946        code,
3947        "    /// committed once at the end, avoiding per-operation synchronization."
3948    )?;
3949    writeln!(
3950        code,
3951        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3952    )?;
3953    writeln!(
3954        code,
3955        "        // Wait for any pending prefill command buffer"
3956    )?;
3957    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3958    writeln!(code, "            prev.wait_until_completed();")?;
3959    writeln!(code, "        }}")?;
3960    writeln!(code)?;
3961    writeln!(code, "        let pos = self.pos;")?;
3962    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3963    writeln!(code)?;
3964
3965    // Single compute encoder for the entire forward pass — no blit encoder
3966    // transitions. Copy operations use compute copy kernels instead of blits.
3967    let matmul_fn = if is_q8 {
3968        "dispatch_matmul_q8"
3969    } else if is_q4 {
3970        "dispatch_matmul_q4"
3971    } else {
3972        "dispatch_matmul"
3973    };
3974
3975    writeln!(
3976        code,
3977        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3978    )?;
3979    writeln!(code, "        {{")?;
3980    writeln!(
3981        code,
3982        "            let enc = cmd.new_compute_command_encoder();"
3983    )?;
3984    writeln!(code)?;
3985
3986    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3987    writeln!(
3988        code,
3989        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3990    )?;
3991    writeln!(
3992        code,
3993        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3994    )?;
3995    writeln!(
3996        code,
3997        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3998    )?;
3999    writeln!(
4000        code,
4001        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4002        hidden * 4,
4003    )?;
4004    writeln!(code, "            unsafe {{")?;
4005    writeln!(
4006        code,
4007        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
4008    )?;
4009    writeln!(
4010        code,
4011        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4012    )?;
4013    writeln!(
4014        code,
4015        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
4016    )?;
4017    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
4018    writeln!(
4019        code,
4020        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4021    )?;
4022    writeln!(code, "                    hidden_ptr,")?;
4023    writeln!(code, "                    HIDDEN_SIZE,")?;
4024    writeln!(code, "                );")?;
4025    writeln!(
4026        code,
4027        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4028    )?;
4029    writeln!(code, "            }}")?;
4030    writeln!(code)?;
4031
4032    // 2. Transformer layers
4033    writeln!(code, "            // 2. Transformer layers")?;
4034    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4035    writeln!(code)?;
4036    let q_byte_offset = 0usize;
4037    let k_float_offset = hidden;
4038    let v_float_offset = hidden + kv_dim;
4039
4040    writeln!(
4041        code,
4042        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
4043    )?;
4044    writeln!(
4045        code,
4046        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4047    )?;
4048    writeln!(
4049        code,
4050        "                // Fused Q+K+V matmul: single dispatch for all three projections"
4051    )?;
4052    writeln!(
4053        code,
4054        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4055    )?;
4056    if config.qkv_bias {
4057        writeln!(
4058            code,
4059            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4060        )?;
4061        writeln!(
4062            code,
4063            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4064        )?;
4065    }
4066    writeln!(
4067        code,
4068        "                // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
4069    )?;
4070    writeln!(
4071        code,
4072        "                self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4073    )?;
4074    writeln!(code)?;
4075    writeln!(
4076        code,
4077        "                // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
4078    )?;
4079    writeln!(
4080        code,
4081        "                self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4082    )?;
4083    writeln!(code)?;
4084    writeln!(
4085        code,
4086        "                // Attention using Q from qkv_buf (offset 0)"
4087    )?;
4088    writeln!(
4089        code,
4090        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4091    )?;
4092    writeln!(
4093        code,
4094        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4095    )?;
4096    writeln!(
4097        code,
4098        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4099    )?;
4100    writeln!(
4101        code,
4102        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4103    )?;
4104    writeln!(
4105        code,
4106        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4107    )?;
4108    writeln!(
4109        code,
4110        "                // Fused gate+up matmul: single dispatch for both projections"
4111    )?;
4112    writeln!(
4113        code,
4114        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4115    )?;
4116    writeln!(
4117        code,
4118        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4119    )?;
4120    writeln!(
4121        code,
4122        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4123    )?;
4124    writeln!(
4125        code,
4126        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4127    )?;
4128    writeln!(code, "            }}")?;
4129    writeln!(code)?;
4130
4131    // 3. Final RMS norm + logits
4132    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
4133    writeln!(
4134        code,
4135        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4136    )?;
4137    writeln!(
4138        code,
4139        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4140    )?;
4141    writeln!(code)?;
4142    writeln!(code, "            enc.end_encoding();")?;
4143    writeln!(code, "        }}")?;
4144    writeln!(code)?;
4145
4146    // 5. Single commit + wait, then read back logits
4147    writeln!(
4148        code,
4149        "        // 5. Commit all GPU work and wait for completion"
4150    )?;
4151    writeln!(code, "        cmd.commit();")?;
4152    writeln!(code, "        cmd.wait_until_completed();")?;
4153    writeln!(code)?;
4154    writeln!(code, "        // 6. Read back logits from GPU")?;
4155    writeln!(code, "        let logits = unsafe {{")?;
4156    writeln!(
4157        code,
4158        "            let ptr = self.logits_buf.contents() as *const f32;"
4159    )?;
4160    writeln!(
4161        code,
4162        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4163    )?;
4164    writeln!(code, "        }};")?;
4165    writeln!(code)?;
4166    writeln!(code, "        self.pos += 1;")?;
4167    writeln!(code, "        logits")?;
4168    writeln!(code, "    }}")?;
4169    writeln!(code)?;
4170
4171    // ── forward_profile: instrumented forward with per-operation timing ──
4172    writeln!(
4173        code,
4174        "    /// Profiling forward pass that prints per-stage GPU timing."
4175    )?;
4176    writeln!(code, "    ///")?;
4177    writeln!(
4178        code,
4179        "    /// Each stage is committed and waited on separately so that GPU timestamps"
4180    )?;
4181    writeln!(
4182        code,
4183        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
4184    )?;
4185    writeln!(
4186        code,
4187        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
4188    )?;
4189    writeln!(
4190        code,
4191        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4192    )?;
4193    writeln!(code, "        use std::time::Instant;")?;
4194    writeln!(code)?;
4195    writeln!(
4196        code,
4197        "        // Wait for any pending prefill command buffer"
4198    )?;
4199    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4200    writeln!(code, "            prev.wait_until_completed();")?;
4201    writeln!(code, "        }}")?;
4202    writeln!(code)?;
4203    writeln!(code, "        let pos = self.pos;")?;
4204    writeln!(code)?;
4205
4206    // Stage: embedding (CPU, no GPU)
4207    writeln!(
4208        code,
4209        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
4210    )?;
4211    writeln!(code, "        let t_embed = Instant::now();")?;
4212    writeln!(code, "        unsafe {{")?;
4213    writeln!(
4214        code,
4215        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
4216    )?;
4217    writeln!(
4218        code,
4219        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4220    )?;
4221    writeln!(
4222        code,
4223        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
4224    )?;
4225    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
4226    writeln!(
4227        code,
4228        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4229    )?;
4230    writeln!(code, "                hidden_ptr,")?;
4231    writeln!(code, "                HIDDEN_SIZE,")?;
4232    writeln!(code, "            );")?;
4233    writeln!(
4234        code,
4235        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4236    )?;
4237    writeln!(code, "        }}")?;
4238    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
4239    writeln!(code)?;
4240
4241    // Stage: Transformer layers (all together on GPU)
4242    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
4243    writeln!(code, "        let t_layers = Instant::now();")?;
4244    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4245    writeln!(code, "        {{")?;
4246    writeln!(
4247        code,
4248        "            let enc = cmd.new_compute_command_encoder();"
4249    )?;
4250    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4251    writeln!(
4252        code,
4253        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4254    )?;
4255    writeln!(
4256        code,
4257        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4258    )?;
4259    if config.qkv_bias {
4260        writeln!(
4261            code,
4262            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4263        )?;
4264    }
4265    writeln!(
4266        code,
4267        "                self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4268    )?;
4269    writeln!(
4270        code,
4271        "                self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4272    )?;
4273    writeln!(
4274        code,
4275        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4276    )?;
4277    writeln!(
4278        code,
4279        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4280    )?;
4281    writeln!(
4282        code,
4283        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4284    )?;
4285    writeln!(
4286        code,
4287        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4288    )?;
4289    writeln!(
4290        code,
4291        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4292    )?;
4293    writeln!(
4294        code,
4295        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4296    )?;
4297    writeln!(
4298        code,
4299        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4300    )?;
4301    writeln!(
4302        code,
4303        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4304    )?;
4305    writeln!(code, "            }}")?;
4306    writeln!(code, "            enc.end_encoding();")?;
4307    writeln!(code, "        }}")?;
4308    writeln!(code, "        cmd.commit();")?;
4309    writeln!(code, "        cmd.wait_until_completed();")?;
4310    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4311    writeln!(code)?;
4312
4313    // Stage: Final norm + logits
4314    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4315    writeln!(code, "        let t_logits = Instant::now();")?;
4316    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4317    writeln!(code, "        {{")?;
4318    writeln!(
4319        code,
4320        "            let enc = cmd2.new_compute_command_encoder();"
4321    )?;
4322    writeln!(
4323        code,
4324        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4325    )?;
4326    writeln!(
4327        code,
4328        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4329    )?;
4330    writeln!(code, "            enc.end_encoding();")?;
4331    writeln!(code, "        }}")?;
4332    writeln!(code, "        cmd2.commit();")?;
4333    writeln!(code, "        cmd2.wait_until_completed();")?;
4334    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4335    writeln!(code)?;
4336
4337    // Print profile results
4338    writeln!(
4339        code,
4340        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4341    )?;
4342    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4343    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4344    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4345    writeln!(
4346        code,
4347        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4348    )?;
4349    writeln!(code)?;
4350
4351    // Read back logits
4352    writeln!(code, "        let logits = unsafe {{")?;
4353    writeln!(
4354        code,
4355        "            let ptr = self.logits_buf.contents() as *const f32;"
4356    )?;
4357    writeln!(
4358        code,
4359        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4360    )?;
4361    writeln!(code, "        }};")?;
4362    writeln!(code)?;
4363    writeln!(code, "        self.pos += 1;")?;
4364    writeln!(code, "        logits")?;
4365    writeln!(code, "    }}")?;
4366    writeln!(code)?;
4367
4368    // ── forward_prefill: single-token async forward (backward compat) ──
4369    writeln!(
4370        code,
4371        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4372    )?;
4373    writeln!(code, "    ///")?;
4374    writeln!(
4375        code,
4376        "    /// Commits the command buffer without waiting, enabling double-buffered"
4377    )?;
4378    writeln!(
4379        code,
4380        "    /// execution: GPU processes token N while CPU encodes token N+1."
4381    )?;
4382    writeln!(
4383        code,
4384        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4385    )?;
4386    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4387    writeln!(code, "    }}")?;
4388    writeln!(code)?;
4389
4390    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4391    // Batched matmuls for QKV/O/FFN projections, sequential attention (causal dependency).
4392    let batch_matmul_fn = if is_q8 {
4393        "dispatch_matmul_q8_batch"
4394    } else if is_q4 {
4395        "dispatch_matmul_q4_batch"
4396    } else {
4397        "dispatch_matmul_batch"
4398    };
4399
4400    writeln!(
4401        code,
4402        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4403    )?;
4404    writeln!(code, "    ///")?;
4405    writeln!(
4406        code,
4407        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4408    )?;
4409    writeln!(
4410        code,
4411        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4412    )?;
4413    writeln!(
4414        code,
4415        "    /// This provides significant speedup during prompt prefill."
4416    )?;
4417    writeln!(
4418        code,
4419        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4420    )?;
4421    writeln!(code, "        if tokens.is_empty() {{ return; }}")?;
4422    writeln!(
4423        code,
4424        "        // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4425    )?;
4426    writeln!(
4427        code,
4428        "        // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4429    )?;
4430    writeln!(
4431        code,
4432        "        // longer than that must be processed iteratively.  The KV cache"
4433    )?;
4434    writeln!(code, "        // carries state across chunks via self.pos.")?;
4435    writeln!(
4436        code,
4437        "        for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4438    )?;
4439    writeln!(code, "        let m = chunk.len();")?;
4440    writeln!(code, "        if m == 0 {{ continue; }}")?;
4441    writeln!(code, "        let start_pos = self.pos;")?;
4442    writeln!(code)?;
4443    writeln!(code, "        // Wait for any pending command buffer")?;
4444    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4445    writeln!(code, "            prev.wait_until_completed();")?;
4446    writeln!(code, "        }}")?;
4447    writeln!(code)?;
4448
4449    // Upload token IDs and positions to GPU
4450    writeln!(
4451        code,
4452        "        // Upload token IDs and positions to GPU buffers"
4453    )?;
4454    writeln!(code, "        unsafe {{")?;
4455    writeln!(
4456        code,
4457        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4458    )?;
4459    writeln!(
4460        code,
4461        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4462    )?;
4463    writeln!(code, "            for i in 0..m {{")?;
4464    writeln!(code, "                *tok_ptr.add(i) = chunk[i];")?;
4465    writeln!(
4466        code,
4467        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4468    )?;
4469    writeln!(code, "            }}")?;
4470    writeln!(code, "        }}")?;
4471    writeln!(code)?;
4472
4473    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4474    writeln!(code, "        {{")?;
4475    writeln!(
4476        code,
4477        "            let enc = cmd.new_compute_command_encoder();"
4478    )?;
4479    writeln!(code)?;
4480
4481    // 1. Batch embedding lookup
4482    writeln!(
4483        code,
4484        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4485    )?;
4486    writeln!(
4487        code,
4488        "            self.dispatch_copy_embedding_batch(&enc, m);"
4489    )?;
4490    // Copy batch_hidden -> batch_residual
4491    writeln!(
4492        code,
4493        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4494    )?;
4495    writeln!(code)?;
4496
4497    // 2. Transformer layers
4498    writeln!(code, "            // 2. Transformer layers")?;
4499    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4500    writeln!(code)?;
4501
4502    // Batch RMS norm: residual -> hidden (batched)
4503    writeln!(
4504        code,
4505        "                // Batch RMS norm: batch_residual -> batch_hidden"
4506    )?;
4507    writeln!(
4508        code,
4509        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4510    )?;
4511
4512    // Batch QKV matmul
4513    writeln!(
4514        code,
4515        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4516    )?;
4517    writeln!(
4518        code,
4519        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4520    )?;
4521    if config.qkv_bias {
4522        writeln!(
4523            code,
4524            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4525        )?;
4526        writeln!(
4527            code,
4528            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4529        )?;
4530    }
4531    writeln!(code)?;
4532
4533    // Fused RoPE on Q+K portions in a single dispatch
4534    let k_float_offset = hidden;
4535    writeln!(
4536        code,
4537        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4538    )?;
4539    writeln!(
4540        code,
4541        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4542    )?;
4543    writeln!(code)?;
4544
4545    // Fused KV cache update: copy both K and V in a single dispatch
4546    let v_float_offset = hidden + kv_dim;
4547    writeln!(
4548        code,
4549        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4550    )?;
4551    writeln!(
4552        code,
4553        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4554    )?;
4555    writeln!(code)?;
4556
4557    // Batched causal attention: ONE dispatch for all M tokens
4558    writeln!(
4559        code,
4560        "                // Batched causal attention: one dispatch for all M tokens"
4561    )?;
4562    writeln!(
4563        code,
4564        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4565    )?;
4566    writeln!(code)?;
4567
4568    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4569    writeln!(code, "                // Batched O projection")?;
4570    writeln!(
4571        code,
4572        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4573    )?;
4574    writeln!(code)?;
4575
4576    // Batch add: residual += attn_proj for all tokens
4577    writeln!(
4578        code,
4579        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4580    )?;
4581    writeln!(code)?;
4582
4583    // Batch FFN
4584    writeln!(
4585        code,
4586        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4587    )?;
4588    writeln!(
4589        code,
4590        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4591    )?;
4592    writeln!(
4593        code,
4594        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4595    )?;
4596    writeln!(
4597        code,
4598        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4599    )?;
4600    writeln!(
4601        code,
4602        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4603    )?;
4604    writeln!(
4605        code,
4606        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4607    )?;
4608    writeln!(code, "            }}")?;
4609    writeln!(code)?;
4610
4611    // Copy last token's residual to single-token residual_buf for next forward() call
4612    writeln!(
4613        code,
4614        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4615    )?;
4616    writeln!(
4617        code,
4618        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4619    )?;
4620    writeln!(code)?;
4621    writeln!(code, "            enc.end_encoding();")?;
4622    writeln!(code, "        }}")?;
4623    writeln!(code)?;
4624
4625    writeln!(code, "        cmd.commit();")?;
4626    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4627    writeln!(code, "        self.pos += m;")?;
4628    writeln!(code, "        }}  // end for chunk")?;
4629    writeln!(code, "    }}")?;
4630    writeln!(code)?;
4631
4632    // ── reset() — rewind KV cache position for new inference requests ──
4633    writeln!(
4634        code,
4635        "    /// Reset the model state for a new inference request."
4636    )?;
4637    writeln!(code, "    pub fn reset(&mut self) {{")?;
4638    writeln!(code, "        self.pos = 0;")?;
4639    writeln!(code, "        self.prev_cmd = None;")?;
4640    writeln!(code, "    }}")?;
4641    writeln!(code)?;
4642
4643    // ── Private dispatch helpers (all take a shared compute encoder) ──
4644    writeln!(
4645        code,
4646        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4647    )?;
4648    writeln!(
4649        code,
4650        "    // These methods set pipeline state + buffers + dispatch on an existing"
4651    )?;
4652    writeln!(
4653        code,
4654        "    // encoder, avoiding per-operation encoder creation overhead."
4655    )?;
4656    writeln!(code)?;
4657
4658    // dispatch_rms_norm
4659    writeln!(
4660        code,
4661        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4662    )?;
4663    writeln!(
4664        code,
4665        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4666    )?;
4667    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4668    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4669    writeln!(
4670        code,
4671        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4672    )?;
4673    writeln!(
4674        code,
4675        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4676    )?;
4677    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4678    writeln!(
4679        code,
4680        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4681    )?;
4682    writeln!(
4683        code,
4684        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4685    )?;
4686    writeln!(
4687        code,
4688        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4689    )?;
4690    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4691    writeln!(
4692        code,
4693        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4694    )?;
4695    writeln!(
4696        code,
4697        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4698    )?;
4699    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4700    writeln!(code, "    }}")?;
4701    writeln!(code)?;
4702
4703    // dispatch_matmul
4704    writeln!(
4705        code,
4706        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4707    )?;
4708    writeln!(
4709        code,
4710        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4711    )?;
4712    writeln!(code, "        let r: u32 = rows as u32;")?;
4713    writeln!(code, "        let c: u32 = cols as u32;")?;
4714    writeln!(
4715        code,
4716        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4717    )?;
4718    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4719    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4720    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4721    writeln!(
4722        code,
4723        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4724    )?;
4725    writeln!(
4726        code,
4727        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4728    )?;
4729    writeln!(
4730        code,
4731        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4732    )?;
4733    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4734    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4735    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4736    writeln!(
4737        code,
4738        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4739    )?;
4740    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4741    writeln!(code, "    }}")?;
4742    writeln!(code)?;
4743
4744    // dispatch_matmul_q8
4745    writeln!(
4746        code,
4747        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4748    )?;
4749    writeln!(
4750        code,
4751        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4752    )?;
4753    writeln!(
4754        code,
4755        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4756    )?;
4757    writeln!(code, "        let r: u32 = rows as u32;")?;
4758    writeln!(code, "        let c: u32 = cols as u32;")?;
4759    writeln!(
4760        code,
4761        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4762    )?;
4763    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4764    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4765    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4766    writeln!(
4767        code,
4768        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4769    )?;
4770    writeln!(
4771        code,
4772        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4773    )?;
4774    writeln!(
4775        code,
4776        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4777    )?;
4778    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4779    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4780    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4781    writeln!(
4782        code,
4783        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4784    )?;
4785    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4786    writeln!(code, "    }}")?;
4787    writeln!(code)?;
4788
4789    // dispatch_matmul_q4
4790    writeln!(
4791        code,
4792        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4793    )?;
4794    writeln!(
4795        code,
4796        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4797    )?;
4798    writeln!(
4799        code,
4800        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4801    )?;
4802    writeln!(code, "        let r: u32 = rows as u32;")?;
4803    writeln!(code, "        let c: u32 = cols as u32;")?;
4804    writeln!(
4805        code,
4806        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4807    )?;
4808    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4809    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4810    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4811    writeln!(
4812        code,
4813        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4814    )?;
4815    writeln!(
4816        code,
4817        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4818    )?;
4819    writeln!(
4820        code,
4821        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4822    )?;
4823    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4824    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4825    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4826    writeln!(
4827        code,
4828        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4829    )?;
4830    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4831    writeln!(code, "    }}")?;
4832    writeln!(code)?;
4833
4834    // dispatch_rope
4835    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4836    writeln!(
4837        code,
4838        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4839    )?;
4840    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4841    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4842    writeln!(code, "        let p: u32 = pos as u32;")?;
4843    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4844    writeln!(
4845        code,
4846        "        let total_pairs = num_heads * (head_dim / 2);"
4847    )?;
4848    writeln!(
4849        code,
4850        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4851    )?;
4852    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4853    writeln!(
4854        code,
4855        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4856    )?;
4857    writeln!(
4858        code,
4859        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4860    )?;
4861    writeln!(
4862        code,
4863        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4864    )?;
4865    writeln!(
4866        code,
4867        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4868    )?;
4869    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4870    writeln!(
4871        code,
4872        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4873    )?;
4874    writeln!(
4875        code,
4876        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4877    )?;
4878    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4879    writeln!(code, "    }}")?;
4880    writeln!(code)?;
4881
4882    // dispatch_rope_offset
4883    writeln!(
4884        code,
4885        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4886    )?;
4887    writeln!(
4888        code,
4889        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4890    )?;
4891    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4892    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4893    writeln!(code, "        let p: u32 = pos as u32;")?;
4894    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4895    writeln!(
4896        code,
4897        "        let total_pairs = num_heads * (head_dim / 2);"
4898    )?;
4899    writeln!(
4900        code,
4901        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4902    )?;
4903    writeln!(
4904        code,
4905        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4906    )?;
4907    writeln!(
4908        code,
4909        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4910    )?;
4911    writeln!(
4912        code,
4913        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4914    )?;
4915    writeln!(
4916        code,
4917        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4918    )?;
4919    writeln!(
4920        code,
4921        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4922    )?;
4923    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4924    writeln!(
4925        code,
4926        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4927    )?;
4928    writeln!(
4929        code,
4930        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4931    )?;
4932    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4933    writeln!(code, "    }}")?;
4934    writeln!(code)?;
4935
    // dispatch_attention
    //
    // Emits `fn dispatch_attention`: binds Q, the K cache, the V cache and the
    // output buffer at offset 0 (slots 0-3), then seq_len, NUM_HEADS,
    // NUM_KV_HEADS and HEAD_DIM as u32 constants (slots 4-7).
    // Grid: one 256-thread threadgroup per query head (NUM_HEADS threadgroups).
    // Passing NUM_KV_HEADS separately suggests the kernel maps several query
    // heads onto one KV head (grouped-query attention) — confirm in the shader.
    // The emitted function ends with `memoryBarrierWithScope:1` (scope 1
    // presumably MTLBarrierScopeBuffers) so later dispatches see the output.
    writeln!(
        code,
        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
    )?;
    writeln!(
        code,
        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
    )?;
    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
4989
    // dispatch_attention_offset
    //
    // Emits `fn dispatch_attention_offset`. It uses the same pipeline
    // (`attention_pipeline`), constants (slots 4-7) and grid (NUM_HEADS
    // threadgroups of 256 threads) as `dispatch_attention`. The difference is
    // that Q is bound at `q_byte_offset` inside `q_buf`, so Q can be read
    // directly from a fused QKV buffer. K cache, V cache and output are bound
    // at offset 0. The emitted function ends with a `memoryBarrierWithScope:1`
    // message (scope 1 presumably MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
    )?;
    writeln!(
        code,
        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
    )?;
    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
    )?;
    writeln!(
        code,
        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
    )?;
    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5042
    // dispatch_silu_mul
    //
    // Emits `fn dispatch_silu_mul`: binds gate, up and output at offset 0
    // (slots 0-2) and the element count `n` as u32 (slot 3), then launches
    // ceil(n / 256) threadgroups of 256 threads on `silu_mul_pipeline`.
    // Ends with a `memoryBarrierWithScope:1` message (scope 1 presumably
    // MTLBarrierScopeBuffers).
    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
    writeln!(
        code,
        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
    )?;
    writeln!(code, "        let count: u32 = n as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5073
    // dispatch_silu_mul_fused
    //
    // Emits `fn dispatch_silu_mul_fused`: variant of `dispatch_silu_mul` that
    // binds a single `gate_up` buffer (slot 0) holding gate and up back to back,
    // plus the output (slot 1) and `n` as u32 (slot 2). Runs on
    // `silu_mul_fused_pipeline` with ceil(n / 256) threadgroups of 256 threads,
    // i.e. one thread per output element. Ends with a `memoryBarrierWithScope:1`
    // message (scope 1 presumably MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
    )?;
    writeln!(
        code,
        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
    )?;
    writeln!(
        code,
        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
    )?;
    writeln!(code, "        let count: u32 = n as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5110
    // dispatch_copy (simple src -> dst copy via compute kernel)
    //
    // Emits `fn dispatch_copy`: binds src (slot 0) and dst (slot 1) at offset 0
    // and `count` as u32 (slot 2) on `copy_pipeline`, launching
    // ceil(count / 256) threadgroups of 256 threads. A compute kernel is used
    // rather than a blit encoder, so the copy can share the compute encoder with
    // the surrounding dispatches. Ends with a `memoryBarrierWithScope:1` message
    // (scope 1 presumably MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
    )?;
    writeln!(
        code,
        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
    )?;
    writeln!(code, "        let n: u32 = count as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5143
    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
    //
    // Emits `fn dispatch_copy_offset` on `copy_offset_pipeline`. Unlike
    // `dispatch_copy_from_offset`, both buffers are bound at offset 0. The
    // element offset `src_offset` is passed to the kernel as a u32 constant
    // (slot 2), with `count` in slot 3. Grid: ceil(count / 256) threadgroups of
    // 256 threads.
    // NOTE(review): the emitted `src_offset as u32` silently truncates offsets
    // >= 2^32 elements; fine for current embedding tables, but worth a check
    // for very large vocab * hidden products.
    // Ends with a `memoryBarrierWithScope:1` message (scope 1 presumably
    // MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
    )?;
    writeln!(
        code,
        "    /// Used for embedding table lookup (copy a specific row)."
    )?;
    writeln!(
        code,
        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
    )?;
    writeln!(code, "        let off: u32 = src_offset as u32;")?;
    writeln!(code, "        let n: u32 = count as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5185
    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
    //
    // Emits `fn dispatch_copy_from_offset`. It reuses the plain `copy_pipeline`
    // and expresses both offsets as Metal buffer-binding offsets instead of
    // kernel arguments:
    // - src is bound at `src_byte_offset` bytes;
    // - dst is bound at `dst_float_offset * size_of::<f32>()` bytes.
    // The kernel therefore sees both as starting at index 0. `count` (slot 2)
    // is the number of f32 elements to copy; the grid is ceil(count / 256)
    // threadgroups of 256 threads. Ends with a `memoryBarrierWithScope:1`
    // message (scope 1 presumably MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch copy from source at byte offset to destination at float offset."
    )?;
    writeln!(
        code,
        "    /// Used for KV cache updates from fused QKV buffer."
    )?;
    writeln!(
        code,
        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
    )?;
    writeln!(code, "        let n: u32 = count as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
    )?;
    writeln!(
        code,
        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
    )?;
    writeln!(
        code,
        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5228
5229    // dispatch_copy_from_offset_f16 (copy src[src_byte_offset..] -> f16 dst[dst_elem_offset..])
5230    writeln!(
5231        code,
5232        "    /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5233    )?;
5234    writeln!(
5235        code,
5236        "    /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5237    )?;
5238    writeln!(
5239        code,
5240        "    fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5241    )?;
5242    writeln!(code, "        let n: u32 = count as u32;")?;
5243    writeln!(
5244        code,
5245        "        enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5246    )?;
5247    writeln!(
5248        code,
5249        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5250    )?;
5251    writeln!(
5252        code,
5253        "        enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5254    )?;
5255    writeln!(
5256        code,
5257        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5258    )?;
5259    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5260    writeln!(
5261        code,
5262        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5263    )?;
5264    writeln!(
5265        code,
5266        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5267    )?;
5268    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5269    writeln!(code, "    }}")?;
5270    writeln!(code)?;
5271
    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
    //
    // Emits `fn dispatch_copy_to_offset`. It reuses `copy_pipeline`, binding
    // src at offset 0 and dst at `dst_offset * size_of::<f32>()` bytes, so
    // `dst_offset` counts f32 elements. `count` goes in slot 2; the grid is
    // ceil(count / 256) threadgroups of 256 threads. Ends with a
    // `memoryBarrierWithScope:1` message (scope 1 presumably
    // MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
    )?;
    writeln!(
        code,
        "    /// Used for KV cache updates (write to a specific position in the cache)."
    )?;
    writeln!(
        code,
        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
    )?;
    writeln!(code, "        let n: u32 = count as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
    writeln!(
        code,
        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5311
    // dispatch_add_inplace (residual connection, no blit needed)
    //
    // Emits `fn dispatch_add_inplace`: binds `a` (slot 0, read and written)
    // and `b` (slot 1) at offset 0, plus the element count as u32 (slot 2), on
    // `add_inplace_pipeline`. Grid: ceil(n / 256) threadgroups of 256 threads.
    // Because the sum is accumulated into `a`, the residual stream needs no
    // separate copy. Ends with a `memoryBarrierWithScope:1` message (scope 1
    // presumably MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
    )?;
    writeln!(
        code,
        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
    )?;
    writeln!(code, "        let count: u32 = n as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5344
    // ── Batched prefill dispatch helpers ──
    //
    // The helpers emitted below each take a `num_tokens` count and cover that
    // many tokens in one dispatch. The section marker itself is also written
    // into the generated file.
    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
    writeln!(code)?;

    // dispatch_copy_embedding_batch
    //
    // Emits `fn dispatch_copy_embedding_batch`. Unlike the other helpers it
    // takes no buffer parameters; it binds fixed model fields instead:
    // `embed_buf` (slot 0), `batch_hidden_buf` (slot 1) and `batch_tokens_buf`
    // (slot 2). HIDDEN_SIZE and num_tokens go in slots 3-4 as u32.
    // Grid: ceil(num_tokens * HIDDEN_SIZE / 256) threadgroups of 256 threads,
    // i.e. one thread per output element.
    // NOTE(review): presumably `batch_tokens_buf` must already hold this
    // batch's token ids when this is called — confirm at the call site.
    // Ends with a `memoryBarrierWithScope:1` message (scope 1 presumably
    // MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
    )?;
    writeln!(
        code,
        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
    )?;
    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
    writeln!(
        code,
        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
    )?;
    writeln!(
        code,
        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
    )?;
    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5394
    // dispatch_rms_norm_batch
    //
    // Emits `fn dispatch_rms_norm_batch`. Bindings:
    // - buffers: input, weight and output at offset 0 (slots 0-2);
    // - constants: HIDDEN_SIZE as u32 (slot 3), RMS_NORM_EPS as f32 (slot 4),
    //   num_tokens as u32 (slot 5).
    // Grid: one 256-thread threadgroup per token (num_tokens threadgroups), so
    // each token's vector is reduced by a single threadgroup. Ends with a
    // `memoryBarrierWithScope:1` message (scope 1 presumably
    // MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
    )?;
    writeln!(
        code,
        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
    )?;
    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
    )?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(
        code,
        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
    )?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5438
    // dispatch_matmul_batch (f32)
    //
    // Emits `fn dispatch_matmul_batch` on `matmul_batch_pipeline`.
    // Slot order: weight is bound to slot 0 and input to slot 1, which is the
    // reverse of the Rust parameter order; output is slot 2. num_tokens, rows
    // and cols go in slots 3-5 as u32.
    // Grid: a flat 1-D launch of ceil(rows / 64) * num_tokens threadgroups of
    // 256 threads (64 rows per threadgroup, per the emitted comment).
    // Ends with a `memoryBarrierWithScope:1` message (scope 1 presumably
    // MTLBarrierScopeBuffers).
    writeln!(
        code,
        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
    )?;
    writeln!(
        code,
        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
    )?;
    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
    writeln!(code, "        let r: u32 = rows as u32;")?;
    writeln!(code, "        let c: u32 = cols as u32;")?;
    writeln!(
        code,
        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
    )?;
    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
    writeln!(
        code,
        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
    )?;
    writeln!(
        code,
        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
    )?;
    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
    writeln!(
        code,
        "        enc.dispatch_thread_groups(grid_size, tg_size);"
    )?;
    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
    writeln!(code, "    }}")?;
    writeln!(code)?;
5484
5485    // dispatch_matmul_q8_batch
5486    writeln!(
5487        code,
5488        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5489    )?;
5490    writeln!(code, "    ///")?;
5491    writeln!(
5492        code,
5493        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5494    )?;
5495    writeln!(
5496        code,
5497        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5498    )?;
5499    writeln!(
5500        code,
5501        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5502    )?;
5503    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5504    writeln!(code, "        let r: u32 = rows as u32;")?;
5505    writeln!(code, "        let c: u32 = cols as u32;")?;
5506    writeln!(
5507        code,
5508        "        // Tile sizes must match the Metal shader constants."
5509    )?;
5510    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5511    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5512    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5513    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5514    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5515    writeln!(
5516        code,
5517        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5518    )?;
5519    writeln!(
5520        code,
5521        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5522    )?;
5523    writeln!(
5524        code,
5525        "        // dispatch count and reuses each weight load across 32 tokens."
5526    )?;
5527    writeln!(
5528        code,
5529        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5530    )?;
5531    writeln!(
5532        code,
5533        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5534    )?;
5535    writeln!(
5536        code,
5537        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5538    )?;
5539    writeln!(
5540        code,
5541        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5542    )?;
5543    writeln!(
5544        code,
5545        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5546    )?;
5547    writeln!(
5548        code,
5549        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5550    )?;
5551    writeln!(
5552        code,
5553        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5554    )?;
5555    writeln!(
5556        code,
5557        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5558    )?;
5559    writeln!(
5560        code,
5561        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5562    )?;
5563    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5564    writeln!(code, "            let pipe = if use_h4 {{")?;
5565    writeln!(code, "                if num_tokens >= 256 {{")?;
5566    writeln!(
5567        code,
5568        "                    &self.matmul_q8_mma32_hh4_pipeline"
5569    )?;
5570    writeln!(code, "                }} else {{")?;
5571    writeln!(
5572        code,
5573        "                    &self.matmul_q8_mma32_h4_pipeline"
5574    )?;
5575    writeln!(code, "                }}")?;
5576    writeln!(code, "            }} else {{")?;
5577    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5578    writeln!(code, "            }};")?;
5579    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
5580    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5581    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5582    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5583    writeln!(
5584        code,
5585        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5586    )?;
5587    writeln!(
5588        code,
5589        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5590    )?;
5591    writeln!(
5592        code,
5593        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5594    )?;
5595    writeln!(code, "            let row_tgs = rows / MMA32_ROW_TILE;")?;
5596    writeln!(
5597        code,
5598        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5599    )?;
5600    writeln!(
5601        code,
5602        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5603    )?;
5604    writeln!(
5605        code,
5606        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5607    )?;
5608    writeln!(
5609        code,
5610        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5611    )?;
5612    writeln!(
5613        code,
5614        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5615    )?;
5616    writeln!(
5617        code,
5618        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5619    )?;
5620    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5621    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5622    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5623    writeln!(
5624        code,
5625        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5626    )?;
5627    writeln!(
5628        code,
5629        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5630    )?;
5631    writeln!(
5632        code,
5633        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5634    )?;
5635    writeln!(code, "            let row_tgs = rows / MMA_ROW_TILE;")?;
5636    writeln!(
5637        code,
5638        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5639    )?;
5640    writeln!(
5641        code,
5642        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5643    )?;
5644    writeln!(
5645        code,
5646        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5647    )?;
5648    writeln!(
5649        code,
5650        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5651    )?;
5652    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5653    writeln!(
5654        code,
5655        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5656    )?;
5657    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5658    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5659    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5660    writeln!(
5661        code,
5662        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5663    )?;
5664    writeln!(
5665        code,
5666        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5667    )?;
5668    writeln!(
5669        code,
5670        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5671    )?;
5672    writeln!(
5673        code,
5674        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5675    )?;
5676    writeln!(
5677        code,
5678        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5679    )?;
5680    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5681    writeln!(
5682        code,
5683        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5684    )?;
5685    writeln!(
5686        code,
5687        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5688    )?;
5689    writeln!(code, "        }} else {{")?;
5690    writeln!(
5691        code,
5692        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5693    )?;
5694    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5695    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5696    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5697    writeln!(
5698        code,
5699        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5700    )?;
5701    writeln!(
5702        code,
5703        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5704    )?;
5705    writeln!(
5706        code,
5707        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5708    )?;
5709    writeln!(
5710        code,
5711        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5712    )?;
5713    writeln!(
5714        code,
5715        "            let num_tg = (row_tgs * num_tokens) as u64;"
5716    )?;
5717    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5718    writeln!(
5719        code,
5720        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5721    )?;
5722    writeln!(
5723        code,
5724        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5725    )?;
5726    writeln!(code, "        }}")?;
5727    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5728    writeln!(code, "    }}")?;
5729    writeln!(code)?;
5730
5731    // dispatch_matmul_q4_batch
5732    writeln!(
5733        code,
5734        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5735    )?;
5736    writeln!(
5737        code,
5738        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5739    )?;
5740    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5741    writeln!(code, "        let r: u32 = rows as u32;")?;
5742    writeln!(code, "        let c: u32 = cols as u32;")?;
5743    writeln!(
5744        code,
5745        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5746    )?;
5747    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5748    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5749    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5750    writeln!(
5751        code,
5752        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5753    )?;
5754    writeln!(
5755        code,
5756        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5757    )?;
5758    writeln!(
5759        code,
5760        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5761    )?;
5762    writeln!(
5763        code,
5764        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5765    )?;
5766    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5767    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5768    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5769    writeln!(
5770        code,
5771        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5772    )?;
5773    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5774    writeln!(code, "    }}")?;
5775    writeln!(code)?;
5776
5777    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5778    if config.qkv_bias {
5779        writeln!(
5780            code,
5781            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5782        )?;
5783        writeln!(
5784            code,
5785            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5786        )?;
5787        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5788        writeln!(code, "        let r: u32 = rows as u32;")?;
5789        writeln!(
5790            code,
5791            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5792        )?;
5793        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5794        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5795        writeln!(
5796            code,
5797            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5798        )?;
5799        writeln!(
5800            code,
5801            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5802        )?;
5803        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5804        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5805        writeln!(
5806            code,
5807            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5808        )?;
5809        writeln!(
5810            code,
5811            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5812        )?;
5813        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5814        writeln!(code, "    }}")?;
5815        writeln!(code)?;
5816    }
5817
5818    // dispatch_rope_batch
5819    writeln!(
5820        code,
5821        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5822    )?;
5823    writeln!(
5824        code,
5825        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5826    )?;
5827    writeln!(
5828        code,
5829        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5830    )?;
5831    writeln!(
5832        code,
5833        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5834    )?;
5835    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5836    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5837    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5838    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5839    writeln!(
5840        code,
5841        "        let pairs_per_token = num_heads * (head_dim / 2);"
5842    )?;
5843    writeln!(
5844        code,
5845        "        let total_pairs = num_tokens * pairs_per_token;"
5846    )?;
5847    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data.
5848    // Since our batch_qkv_buf is [M, qkv_rows] and Q/K are at offsets within each row,
5849    // we need to pass the buffer at the right byte offset for each token's data.
5850    // Actually, the rope_batch kernel accesses data[token * (num_heads * head_dim) + ...],
5851    // but our layout is data[token * row_stride + data_float_offset + ...].
5852    // We need the kernel to know the row_stride. Let me adjust the kernel approach:
5853    // Since Q and K are contiguous within each token's qkv_rows, and the batch buffer
5854    // is [M, qkv_rows], we can pass the buffer at offset (data_float_offset * 4) and
5855    // use a stride parameter. But the rope_batch kernel as written expects [M, num_heads*head_dim].
5856    //
5857    // Simplest approach: use the single-token rope kernel for each token in a loop.
5858    // This is still efficient because we're dispatching all within the same command encoder.
5859    writeln!(
5860        code,
5861        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5862    )?;
5863    writeln!(code, "        for t in 0..num_tokens {{")?;
5864    writeln!(
5865        code,
5866        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5867    )?;
5868    writeln!(
5869        code,
5870        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5871    )?;
5872    writeln!(
5873        code,
5874        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5875    )?;
5876    writeln!(
5877        code,
5878        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5879    )?;
5880    writeln!(
5881        code,
5882        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5883    )?;
5884    writeln!(
5885        code,
5886        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5887    )?;
5888    writeln!(
5889        code,
5890        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5891    )?;
5892    writeln!(
5893        code,
5894        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5895    )?;
5896    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5897    writeln!(
5898        code,
5899        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5900    )?;
5901    writeln!(
5902        code,
5903        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5904    )?;
5905    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5906    writeln!(code, "        }}")?;
5907    writeln!(code, "    }}")?;
5908    writeln!(code)?;
5909
5910    // dispatch_silu_mul_fused_batch
5911    writeln!(
5912        code,
5913        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5914    )?;
5915    writeln!(
5916        code,
5917        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5918    )?;
5919    writeln!(code, "        let count: u32 = n as u32;")?;
5920    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5921    writeln!(
5922        code,
5923        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5924    )?;
5925    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5926    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5927    writeln!(
5928        code,
5929        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5930    )?;
5931    writeln!(
5932        code,
5933        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5934    )?;
5935    writeln!(code, "        let total = num_tokens * n;")?;
5936    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5937    writeln!(
5938        code,
5939        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5940    )?;
5941    writeln!(
5942        code,
5943        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5944    )?;
5945    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5946    writeln!(code, "    }}")?;
5947    writeln!(code)?;
5948
5949    // dispatch_add_inplace_batch_n (add n elements in-place)
5950    writeln!(
5951        code,
5952        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5953    )?;
5954    writeln!(
5955        code,
5956        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5957    )?;
5958    writeln!(code, "        let count: u32 = total_n as u32;")?;
5959    writeln!(
5960        code,
5961        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5962    )?;
5963    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5964    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5965    writeln!(
5966        code,
5967        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5968    )?;
5969    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5970    writeln!(
5971        code,
5972        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5973    )?;
5974    writeln!(
5975        code,
5976        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5977    )?;
5978    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5979    writeln!(code, "    }}")?;
5980    writeln!(code)?;
5981
5982    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5983    writeln!(
5984        code,
5985        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5986    )?;
5987    writeln!(
5988        code,
5989        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5990    )?;
5991    writeln!(code, "        let n: u32 = count as u32;")?;
5992    writeln!(
5993        code,
5994        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5995    )?;
5996    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5997    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5998    writeln!(
5999        code,
6000        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6001    )?;
6002    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6003    writeln!(
6004        code,
6005        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6006    )?;
6007    writeln!(
6008        code,
6009        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6010    )?;
6011    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6012    writeln!(code, "    }}")?;
6013    writeln!(code)?;
6014
6015    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
6016    writeln!(
6017        code,
6018        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6019    )?;
6020    writeln!(
6021        code,
6022        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6023    )?;
6024    writeln!(code, "        let n: u32 = count as u32;")?;
6025    writeln!(
6026        code,
6027        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6028    )?;
6029    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6030    writeln!(
6031        code,
6032        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6033    )?;
6034    writeln!(
6035        code,
6036        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6037    )?;
6038    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6039    writeln!(
6040        code,
6041        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6042    )?;
6043    writeln!(
6044        code,
6045        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6046    )?;
6047    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6048    writeln!(code, "    }}")?;
6049    writeln!(code)?;
6050
6051    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
6052    writeln!(
6053        code,
6054        "    /// Copy from src at byte offset to dst at float offset."
6055    )?;
6056    writeln!(
6057        code,
6058        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6059    )?;
6060    writeln!(code, "        let n: u32 = count as u32;")?;
6061    writeln!(
6062        code,
6063        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6064    )?;
6065    writeln!(
6066        code,
6067        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6068    )?;
6069    writeln!(
6070        code,
6071        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6072    )?;
6073    writeln!(
6074        code,
6075        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6076    )?;
6077    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6078    writeln!(
6079        code,
6080        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6081    )?;
6082    writeln!(
6083        code,
6084        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6085    )?;
6086    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6087    writeln!(code, "    }}")?;
6088    writeln!(code)?;
6089
6090    // dispatch_copy_kv_batch
6091    writeln!(
6092        code,
6093        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6094    )?;
6095    writeln!(
6096        code,
6097        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6098    )?;
6099    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6100    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6101    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6102    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6103    writeln!(code, "        let so: u32 = src_offset as u32;")?;
6104    writeln!(
6105        code,
6106        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6107    )?;
6108    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6109    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6110    writeln!(
6111        code,
6112        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6113    )?;
6114    writeln!(
6115        code,
6116        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6117    )?;
6118    writeln!(
6119        code,
6120        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6121    )?;
6122    writeln!(
6123        code,
6124        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6125    )?;
6126    writeln!(
6127        code,
6128        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6129    )?;
6130    writeln!(code, "        let total = num_tokens * kv_dim;")?;
6131    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6132    writeln!(
6133        code,
6134        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6135    )?;
6136    writeln!(
6137        code,
6138        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6139    )?;
6140    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6141    writeln!(code, "    }}")?;
6142    writeln!(code)?;
6143
6144    // dispatch_attention_batch
6145    writeln!(
6146        code,
6147        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
6148    )?;
6149    writeln!(
6150        code,
6151        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6152    )?;
6153    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6154    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6155    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
6156    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6157    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6158    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
6159    // Attention kernel selection:
6160    //   * Legacy `attention_batch` materializes scores[4096] in threadgroup memory
6161    //     and uses scalar simdgroup reductions.  Fast at short seq_len, no MMA.
6162    //   * `attention_flash_batch` streams K/V with online softmax; no seq cap,
6163    //     scalar math, ~7-14 % slower than legacy at long contexts (no MMA).
6164    //   * `attention_mma_flash_batch` adds hardware simdgroup_matrix<half, 8, 8>
6165    //     MMA for both Q·K^T and P·V, processing Q_BLOCK=8 tokens per TG.
6166    //     Default path when HEAD_DIM ≤ 128 and num_tokens ≥ 8 (verified on
6167    //     Llama, Qwen2.5, Phi-3).  Set FORGE_MMA_ATTN=0 to force legacy.
6168    writeln!(code, "        let max_seq = base_pos + num_tokens;")?;
6169    writeln!(code, "        let _ = max_seq;")?;
6170    writeln!(
6171        code,
6172        "        let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6173    )?;
6174    writeln!(code, "            .map(|v| v == \"0\").unwrap_or(false);")?;
6175    writeln!(
6176        code,
6177        "        let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6178    )?;
6179    writeln!(code, "        if use_mma_flash {{")?;
6180    writeln!(
6181        code,
6182        "            let pipe = &self.attention_mma_flash_batch_pipeline;"
6183    )?;
6184    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
6185    writeln!(code, "            enc.set_buffer(0, Some(q_buf), 0);")?;
6186    writeln!(code, "            enc.set_buffer(1, Some(k_cache), 0);")?;
6187    writeln!(code, "            enc.set_buffer(2, Some(v_cache), 0);")?;
6188    writeln!(code, "            enc.set_buffer(3, Some(output), 0);")?;
6189    writeln!(
6190        code,
6191        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6192    )?;
6193    writeln!(
6194        code,
6195        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6196    )?;
6197    writeln!(
6198        code,
6199        "            enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6200    )?;
6201    writeln!(
6202        code,
6203        "            enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6204    )?;
6205    writeln!(
6206        code,
6207        "            enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6208    )?;
6209    writeln!(
6210        code,
6211        "            enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6212    )?;
6213    writeln!(
6214        code,
6215        "            // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6216    )?;
6217    writeln!(code, "            let tg_size = MTLSize::new(128, 1, 1);")?;
6218    writeln!(
6219        code,
6220        "            let q_blocks = ((num_tokens + 7) / 8) as u64;"
6221    )?;
6222    writeln!(
6223        code,
6224        "            let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6225    )?;
6226    writeln!(
6227        code,
6228        "            enc.dispatch_thread_groups(grid_size, tg_size);"
6229    )?;
6230    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6231    writeln!(code, "            return;")?;
6232    writeln!(code, "        }}")?;
6233    writeln!(code, "        let pipe = &self.attention_batch_pipeline;")?;
6234    writeln!(code, "        enc.set_compute_pipeline_state(pipe);")?;
6235    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
6236    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
6237    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
6238    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
6239    writeln!(
6240        code,
6241        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6242    )?;
6243    writeln!(
6244        code,
6245        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6246    )?;
6247    writeln!(
6248        code,
6249        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6250    )?;
6251    writeln!(
6252        code,
6253        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6254    )?;
6255    writeln!(
6256        code,
6257        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6258    )?;
6259    writeln!(
6260        code,
6261        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6262    )?;
6263    writeln!(
6264        code,
6265        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6266    )?;
6267    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6268    writeln!(
6269        code,
6270        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6271    )?;
6272    writeln!(
6273        code,
6274        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6275    )?;
6276    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6277    writeln!(code, "    }}")?;
6278    writeln!(code)?;
6279
6280    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
6281    writeln!(
6282        code,
6283        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6284    )?;
6285    writeln!(
6286        code,
6287        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6288    )?;
6289    writeln!(
6290        code,
6291        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6292    )?;
6293    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6294    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6295    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
6296    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6297    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6298    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
6299    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
6300    writeln!(
6301        code,
6302        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6303    )?;
6304    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
6305    writeln!(
6306        code,
6307        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6308    )?;
6309    writeln!(
6310        code,
6311        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6312    )?;
6313    writeln!(
6314        code,
6315        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6316    )?;
6317    writeln!(
6318        code,
6319        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6320    )?;
6321    writeln!(
6322        code,
6323        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6324    )?;
6325    writeln!(
6326        code,
6327        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6328    )?;
6329    writeln!(
6330        code,
6331        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6332    )?;
6333    writeln!(
6334        code,
6335        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6336    )?;
6337    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6338    writeln!(
6339        code,
6340        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6341    )?;
6342    writeln!(
6343        code,
6344        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6345    )?;
6346    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6347    writeln!(code, "    }}")?;
6348    writeln!(code)?;
6349
6350    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
6351    writeln!(
6352        code,
6353        "    /// Dispatch fused K+V cache copy in one kernel launch."
6354    )?;
6355    writeln!(
6356        code,
6357        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6358    )?;
6359    writeln!(
6360        code,
6361        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6362    )?;
6363    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6364    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6365    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6366    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6367    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6368    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6369    writeln!(
6370        code,
6371        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6372    )?;
6373    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6374    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6375    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6376    writeln!(
6377        code,
6378        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6379    )?;
6380    writeln!(
6381        code,
6382        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6383    )?;
6384    writeln!(
6385        code,
6386        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6387    )?;
6388    writeln!(
6389        code,
6390        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6391    )?;
6392    writeln!(
6393        code,
6394        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6395    )?;
6396    writeln!(
6397        code,
6398        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6399    )?;
6400    writeln!(
6401        code,
6402        "        let total = num_tokens * kv_dim * 2;  // K + V"
6403    )?;
6404    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6405    writeln!(
6406        code,
6407        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6408    )?;
6409    writeln!(
6410        code,
6411        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6412    )?;
6413    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6414    writeln!(code, "    }}")?;
6415
6416    writeln!(code, "}}")?;
6417    writeln!(code)?;
6418
6419    Ok(())
6420}
6421
6422fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6423    writeln!(
6424        code,
6425        "// ── Helper functions ──────────────────────────────────"
6426    )?;
6427    writeln!(code)?;
6428    writeln!(
6429        code,
6430        "/// Create a compute pipeline from a named function in the Metal library."
6431    )?;
6432    writeln!(
6433        code,
6434        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6435    )?;
6436    writeln!(
6437        code,
6438        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6439    )?;
6440    writeln!(
6441        code,
6442        "    device.new_compute_pipeline_state_with_function(&func)"
6443    )?;
6444    writeln!(
6445        code,
6446        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6447    )?;
6448    writeln!(code, "}}")?;
6449    writeln!(code)?;
6450
6451    Ok(())
6452}
6453
6454// ---------------------------------------------------------------------------
6455// main.rs generation
6456// ---------------------------------------------------------------------------
6457
6458fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6459    let _sanitized = sanitize_name(model_name);
6460    let _vocab = config.vocab_size;
6461
6462    let mut code = String::with_capacity(16 * 1024);
6463    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6464    writeln!(
6465        code,
6466        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6467    )?;
6468    writeln!(code)?;
6469    writeln!(code, "mod model;")?;
6470    writeln!(code)?;
6471    writeln!(code, "use std::io::Write;")?;
6472    writeln!(code, "use std::time::Instant;")?;
6473    writeln!(code, "use serde::Deserialize;")?;
6474    writeln!(code)?;
6475
6476    // -- main function --
6477    writeln!(code, "fn main() {{")?;
6478    writeln!(
6479        code,
6480        "    let args: Vec<String> = std::env::args().collect();"
6481    )?;
6482    writeln!(code)?;
6483    writeln!(
6484        code,
6485        "    // Detect --serve mode (only requires weights + tokenizer)"
6486    )?;
6487    writeln!(
6488        code,
6489        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6490    )?;
6491    writeln!(code)?;
6492    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6493    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6494    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6495    writeln!(code, "        std::process::exit(1);")?;
6496    writeln!(code, "    }}")?;
6497    writeln!(code)?;
6498    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6499    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6500    writeln!(code, "        std::process::exit(1);")?;
6501    writeln!(code, "    }}")?;
6502    writeln!(code)?;
6503    writeln!(code, "    let weights_path = &args[1];")?;
6504    writeln!(code, "    let tokenizer_path = &args[2];")?;
6505    writeln!(code)?;
6506    writeln!(code, "    // Parse optional flags")?;
6507    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6508    writeln!(code, "    let mut port: u16 = 8080;")?;
6509    writeln!(
6510        code,
6511        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6512    )?;
6513    writeln!(
6514        code,
6515        "    let profile = args.iter().any(|a| a == \"--profile\");"
6516    )?;
6517    writeln!(code, "    let mut i = 3;")?;
6518    writeln!(code, "    while i < args.len() {{")?;
6519    writeln!(
6520        code,
6521        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6522    )?;
6523    writeln!(
6524        code,
6525        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6526    )?;
6527    writeln!(code, "            i += 2;")?;
6528    writeln!(
6529        code,
6530        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6531    )?;
6532    writeln!(
6533        code,
6534        "            port = args[i + 1].parse().unwrap_or(8080);"
6535    )?;
6536    writeln!(code, "            i += 2;")?;
6537    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6538    writeln!(code, "            i += 1;")?;
6539    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6540    writeln!(code, "            i += 1;")?;
6541    writeln!(code, "        }} else {{")?;
6542    writeln!(code, "            i += 1;")?;
6543    writeln!(code, "        }}")?;
6544    writeln!(code, "    }}")?;
6545    writeln!(code)?;
6546
6547    // -- load model (shared by both modes) --
6548    writeln!(
6549        code,
6550        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6551    )?;
6552    writeln!(
6553        code,
6554        "    let weights_file = std::fs::File::open(weights_path)"
6555    )?;
6556    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6557    writeln!(
6558        code,
6559        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6560    )?;
6561    writeln!(code)?;
6562    writeln!(code, "    // Load tokenizer")?;
6563    writeln!(
6564        code,
6565        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6566    )?;
6567    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6568    writeln!(code)?;
6569    writeln!(code, "    // Create Metal model")?;
6570    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6571    writeln!(
6572        code,
6573        "    let mut model = model::MetalModel::new(&weights_mmap);"
6574    )?;
6575    writeln!(code)?;
6576
6577    // -- branch: serve vs CLI --
6578    writeln!(code, "    if serve_mode {{")?;
6579    writeln!(code, "        serve(model, tokenizer, port);")?;
6580    writeln!(code, "    }} else {{")?;
6581    writeln!(code, "        let prompt = &args[3];")?;
6582    writeln!(
6583        code,
6584        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6585    )?;
6586    writeln!(code, "    }}")?;
6587    writeln!(code, "}}")?;
6588    writeln!(code)?;
6589
6590    // -- cli_mode function --
6591    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6592    writeln!(code, "    // Tokenize prompt")?;
6593    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6594    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6595    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6596    writeln!(code)?;
6597    writeln!(
6598        code,
6599        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6600    )?;
6601    writeln!(
6602        code,
6603        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6604    )?;
6605    writeln!(
6606        code,
6607        "    // The last token uses synchronous forward() to get logits."
6608    )?;
6609    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6610    writeln!(code, "    let prefill_start = Instant::now();")?;
6611    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6612    writeln!(
6613        code,
6614        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6615    )?;
6616    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6617    writeln!(code, "    }} else {{")?;
6618    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6619    writeln!(code, "    }};")?;
6620    writeln!(
6621        code,
6622        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6623    )?;
6624    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6625    writeln!(
6626        code,
6627        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6628    )?;
6629    writeln!(
6630        code,
6631        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6632    )?;
6633    writeln!(code)?;
6634    writeln!(code, "    // Generate tokens")?;
6635    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6636    writeln!(code, "    let gen_start = Instant::now();")?;
6637    writeln!(code, "    let mut generated_count: usize = 0;")?;
6638    writeln!(code)?;
6639    writeln!(code, "    for _ in 0..max_tokens {{")?;
6640    writeln!(
6641        code,
6642        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6643    )?;
6644    writeln!(code, "            if !quiet {{")?;
6645    writeln!(code, "                print!(\"{{}}\", text);")?;
6646    writeln!(code, "                std::io::stdout().flush().ok();")?;
6647    writeln!(code, "            }}")?;
6648    writeln!(code, "        }}")?;
6649    writeln!(code, "        generated_count += 1;")?;
6650    writeln!(code)?;
6651    writeln!(
6652        code,
6653        "        // Use profiling forward for first token when --profile is set"
6654    )?;
6655    writeln!(
6656        code,
6657        "        let logits = if profile && generated_count == 1 {{"
6658    )?;
6659    writeln!(code, "            model.forward_profile(next_token)")?;
6660    writeln!(code, "        }} else {{")?;
6661    writeln!(code, "            model.forward(next_token)")?;
6662    writeln!(code, "        }};")?;
6663    writeln!(code, "        next_token = argmax(&logits);")?;
6664    writeln!(code)?;
6665    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6666    writeln!(code, "        if next_token == 2 {{")?;
6667    writeln!(code, "            break;")?;
6668    writeln!(code, "        }}")?;
6669    writeln!(code)?;
6670    writeln!(
6671        code,
6672        "        // Yield between tokens to reduce sustained GPU thermal load."
6673    )?;
6674    writeln!(
6675        code,
6676        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6677    )?;
6678    writeln!(
6679        code,
6680        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6681    )?;
6682    writeln!(
6683        code,
6684        "        // briefly, providing a micro-break that helps sustain peak throughput."
6685    )?;
6686    writeln!(code, "        std::thread::yield_now();")?;
6687    writeln!(code, "    }}")?;
6688    writeln!(code, "    if !quiet {{")?;
6689    writeln!(code, "        println!();")?;
6690    writeln!(code, "    }}")?;
6691    writeln!(
6692        code,
6693        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6694    )?;
6695    writeln!(
6696        code,
6697        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6698    )?;
6699    writeln!(
6700        code,
6701        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6702    )?;
6703    writeln!(code, "}}")?;
6704    writeln!(code)?;
6705
6706    // -- argmax helper --
6707    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6708    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6709    writeln!(code, "    logits.iter()")?;
6710    writeln!(code, "        .enumerate()")?;
6711    writeln!(
6712        code,
6713        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6714    )?;
6715    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6716    writeln!(code, "        .unwrap_or(0)")?;
6717    writeln!(code, "}}")?;
6718    writeln!(code)?;
6719
6720    // -- Request/Response types for OpenAI API --
6721    writeln!(
6722        code,
6723        "// -----------------------------------------------------------------------"
6724    )?;
6725    writeln!(code, "// OpenAI-compatible API server")?;
6726    writeln!(
6727        code,
6728        "// -----------------------------------------------------------------------"
6729    )?;
6730    writeln!(code)?;
6731    writeln!(code, "#[derive(Deserialize)]")?;
6732    writeln!(code, "struct ChatRequest {{")?;
6733    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6734    writeln!(code, "    #[serde(default)]")?;
6735    writeln!(code, "    stream: Option<bool>,")?;
6736    writeln!(code, "    #[serde(default)]")?;
6737    writeln!(code, "    max_tokens: Option<usize>,")?;
6738    writeln!(code, "    #[serde(default)]")?;
6739    writeln!(code, "    temperature: Option<f32>,")?;
6740    writeln!(code, "    #[serde(default)]")?;
6741    writeln!(code, "    model: Option<String>,")?;
6742    writeln!(code, "}}")?;
6743    writeln!(code)?;
6744    writeln!(code, "#[derive(Deserialize)]")?;
6745    writeln!(code, "struct ChatMessage {{")?;
6746    writeln!(code, "    role: String,")?;
6747    writeln!(code, "    content: String,")?;
6748    writeln!(code, "}}")?;
6749    writeln!(code)?;
6750
6751    // -- format_chat_messages --
6752    writeln!(
6753        code,
6754        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6755    )?;
6756    writeln!(code, "    let mut prompt = String::new();")?;
6757    writeln!(code, "    for msg in messages {{")?;
6758    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6759    writeln!(code, "    }}")?;
6760    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6761    writeln!(code, "    prompt")?;
6762    writeln!(code, "}}")?;
6763    writeln!(code)?;
6764
6765    // -- prefill helper --
6766    writeln!(
6767        code,
6768        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6769    )?;
6770    writeln!(code, "    let len = tokens.len();")?;
6771    writeln!(code, "    if len > 1 {{")?;
6772    writeln!(
6773        code,
6774        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6775    )?;
6776    writeln!(code, "    }}")?;
6777    writeln!(code, "    model.forward(tokens[len - 1])")?;
6778    writeln!(code, "}}")?;
6779    writeln!(code)?;
6780
6781    // -- serve function --
6782    writeln!(
6783        code,
6784        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6785    )?;
6786    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6787    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6788    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6789    writeln!(
6790        code,
6791        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6792    )?;
6793    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6794    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6795    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6796    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6797    writeln!(code)?;
6798    writeln!(code, "    for request in server.incoming_requests() {{")?;
6799    writeln!(code, "        let method = request.method().to_string();")?;
6800    writeln!(code, "        let url = request.url().to_string();")?;
6801    writeln!(code)?;
6802    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6803
6804    // -- POST /v1/chat/completions --
6805    writeln!(
6806        code,
6807        "            (\"POST\", \"/v1/chat/completions\") => {{"
6808    )?;
6809    writeln!(
6810        code,
6811        "                handle_chat_completion(&mut model, &tokenizer, request);"
6812    )?;
6813    writeln!(code, "            }}")?;
6814
6815    // -- GET /v1/models --
6816    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6817    writeln!(code, "                let body = serde_json::json!({{")?;
6818    writeln!(code, "                    \"object\": \"list\",")?;
6819    writeln!(code, "                    \"data\": [{{")?;
6820    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6821    writeln!(code, "                        \"object\": \"model\",")?;
6822    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6823    writeln!(code, "                    }}]")?;
6824    writeln!(code, "                }});")?;
6825    writeln!(
6826        code,
6827        "                let resp = tiny_http::Response::from_string(body.to_string())"
6828    )?;
6829    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6830    writeln!(code, "                request.respond(resp).ok();")?;
6831    writeln!(code, "            }}")?;
6832
6833    // -- GET /health --
6834    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6835    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6836    writeln!(code, "                request.respond(resp).ok();")?;
6837    writeln!(code, "            }}")?;
6838
6839    // -- 404 --
6840    writeln!(code, "            _ => {{")?;
6841    writeln!(
6842        code,
6843        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6844    )?;
6845    writeln!(code, "                    .with_status_code(404);")?;
6846    writeln!(code, "                request.respond(resp).ok();")?;
6847    writeln!(code, "            }}")?;
6848    writeln!(code, "        }}")?;
6849    writeln!(code, "    }}")?;
6850    writeln!(code, "}}")?;
6851    writeln!(code)?;
6852
6853    // -- handle_chat_completion --
6854    writeln!(code, "fn handle_chat_completion(")?;
6855    writeln!(code, "    model: &mut model::MetalModel,")?;
6856    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6857    writeln!(code, "    mut request: tiny_http::Request,")?;
6858    writeln!(code, ") {{")?;
6859    writeln!(code, "    // Read request body")?;
6860    writeln!(code, "    let mut body = String::new();")?;
6861    writeln!(
6862        code,
6863        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6864    )?;
6865    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6866    writeln!(code, "            .with_status_code(400);")?;
6867    writeln!(code, "        request.respond(resp).ok();")?;
6868    writeln!(code, "        return;")?;
6869    writeln!(code, "    }}")?;
6870    writeln!(code)?;
6871    writeln!(code, "    // Parse JSON")?;
6872    writeln!(
6873        code,
6874        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6875    )?;
6876    writeln!(code, "        Ok(r) => r,")?;
6877    writeln!(code, "        Err(e) => {{")?;
6878    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6879    writeln!(code, "                .with_status_code(400);")?;
6880    writeln!(code, "            request.respond(resp).ok();")?;
6881    writeln!(code, "            return;")?;
6882    writeln!(code, "        }}")?;
6883    writeln!(code, "    }};")?;
6884    writeln!(code)?;
6885    writeln!(
6886        code,
6887        "    let prompt = format_chat_messages(&req.messages);"
6888    )?;
6889    writeln!(
6890        code,
6891        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6892    )?;
6893    writeln!(code, "        Ok(e) => e,")?;
6894    writeln!(code, "        Err(e) => {{")?;
6895    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6896    writeln!(code, "                .with_status_code(500);")?;
6897    writeln!(code, "            request.respond(resp).ok();")?;
6898    writeln!(code, "            return;")?;
6899    writeln!(code, "        }}")?;
6900    writeln!(code, "    }};")?;
6901    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6902    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6903    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6904    writeln!(
6905        code,
6906        "    let _temperature = req.temperature.unwrap_or(1.0);"
6907    )?;
6908    writeln!(code)?;
6909
6910    // -- Reset KV cache for each request --
6911    writeln!(code, "    model.reset();")?;
6912    writeln!(code)?;
6913
6914    // -- Prefill with timing --
6915    writeln!(code, "    let prefill_start = Instant::now();")?;
6916    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6917    writeln!(
6918        code,
6919        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6920    )?;
6921    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6922    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6923    writeln!(code)?;
6924
6925    writeln!(code, "    if stream {{")?;
6926
6927    // -- SSE streaming response --
6928    writeln!(
6929        code,
6930        "        // SSE streaming: generate tokens and build SSE body"
6931    )?;
6932    writeln!(code, "        let gen_start = Instant::now();")?;
6933    writeln!(code, "        let mut generated_count: usize = 0;")?;
6934    writeln!(code, "        let mut sse_body = String::new();")?;
6935    writeln!(code, "        for _ in 0..max_tokens {{")?;
6936    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6937    writeln!(
6938        code,
6939        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6940    )?;
6941    writeln!(
6942        code,
6943        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6944    )?;
6945    writeln!(
6946        code,
6947        "                // escaped includes surrounding quotes, strip them"
6948    )?;
6949    writeln!(
6950        code,
6951        "                let inner = &escaped[1..escaped.len()-1];"
6952    )?;
6953    writeln!(code, "                sse_body.push_str(&format!(")?;
6954    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6955    writeln!(code, "                    inner")?;
6956    writeln!(code, "                ));")?;
6957    writeln!(code, "            }}")?;
6958    writeln!(code, "            generated_count += 1;")?;
6959    writeln!(code, "            let logits = model.forward(next_token);")?;
6960    writeln!(code, "            next_token = argmax(&logits);")?;
6961    writeln!(code, "        }}")?;
6962    writeln!(
6963        code,
6964        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6965    )?;
6966    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6967    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6968    writeln!(code)?;
6969    writeln!(
6970        code,
6971        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6972    )?;
6973    writeln!(code, "        sse_body.push_str(&format!(")?;
6974    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6975    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6976    writeln!(code, "        ));")?;
6977    writeln!(code)?;
6978    writeln!(
6979        code,
6980        "        let resp = tiny_http::Response::from_string(sse_body)"
6981    )?;
6982    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6983    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6984    writeln!(code, "        request.respond(resp).ok();")?;
6985
6986    writeln!(code, "    }} else {{")?;
6987
6988    // -- Non-streaming response --
6989    writeln!(
6990        code,
6991        "        // Non-streaming: generate all tokens, return JSON"
6992    )?;
6993    writeln!(code, "        let gen_start = Instant::now();")?;
6994    writeln!(code, "        let mut generated_count: usize = 0;")?;
6995    writeln!(code, "        let mut generated = String::new();")?;
6996    writeln!(code, "        for _ in 0..max_tokens {{")?;
6997    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6998    writeln!(
6999        code,
7000        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7001    )?;
7002    writeln!(code, "                generated.push_str(&text);")?;
7003    writeln!(code, "            }}")?;
7004    writeln!(code, "            generated_count += 1;")?;
7005    writeln!(code, "            let logits = model.forward(next_token);")?;
7006    writeln!(code, "            next_token = argmax(&logits);")?;
7007    writeln!(code, "        }}")?;
7008    writeln!(
7009        code,
7010        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7011    )?;
7012    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7013    writeln!(code)?;
7014    writeln!(code, "        let resp_json = serde_json::json!({{")?;
7015    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
7016    writeln!(code, "            \"object\": \"chat.completion\",")?;
7017    writeln!(code, "            \"choices\": [{{")?;
7018    writeln!(code, "                \"index\": 0,")?;
7019    writeln!(code, "                \"message\": {{")?;
7020    writeln!(code, "                    \"role\": \"assistant\",")?;
7021    writeln!(code, "                    \"content\": generated")?;
7022    writeln!(code, "                }},")?;
7023    writeln!(code, "                \"finish_reason\": \"stop\"")?;
7024    writeln!(code, "            }}],")?;
7025    writeln!(code, "            \"usage\": {{")?;
7026    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
7027    writeln!(
7028        code,
7029        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7030    )?;
7031    writeln!(
7032        code,
7033        "                \"generation_tokens\": generated_count,"
7034    )?;
7035    writeln!(
7036        code,
7037        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7038    )?;
7039    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
7040    writeln!(code, "            }}")?;
7041    writeln!(code, "        }});")?;
7042    writeln!(
7043        code,
7044        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
7045    )?;
7046    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7047    writeln!(code, "        request.respond(resp).ok();")?;
7048    writeln!(code, "    }}")?;
7049    writeln!(code, "}}")?;
7050
7051    Ok(code)
7052}
7053
7054// ---------------------------------------------------------------------------
7055// Tests
7056// ---------------------------------------------------------------------------
7057
7058#[cfg(test)]
7059mod tests {
7060    use super::*;
7061    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7062
7063    fn minimal_config() -> ModelConfig {
7064        ModelConfig {
7065            architecture: Architecture::Llama,
7066            hidden_size: 64,
7067            intermediate_size: 128,
7068            num_layers: 2,
7069            num_attention_heads: 4,
7070            num_kv_heads: 4,
7071            head_dim: 16,
7072            vocab_size: 256,
7073            max_seq_len: 512,
7074            rms_norm_eps: 1e-5,
7075            rope_theta: 10000.0,
7076            dtype: DType::F32,
7077            sliding_window_size: None,
7078            qkv_bias: false,
7079        }
7080    }
7081
7082    fn minimal_graph() -> Graph {
7083        Graph::new("test-metal").with_config(minimal_config())
7084    }
7085
7086    #[test]
7087    fn generate_metal_project_creates_files() {
7088        let dir = tempfile::tempdir().unwrap();
7089        let graph = minimal_graph();
7090        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7091
7092        assert!(
7093            dir.path().join("Cargo.toml").exists(),
7094            "Cargo.toml should be created"
7095        );
7096        assert!(
7097            dir.path().join("src/model.rs").exists(),
7098            "src/model.rs should be created"
7099        );
7100        assert!(
7101            dir.path().join("src/main.rs").exists(),
7102            "src/main.rs should be created"
7103        );
7104        assert!(
7105            dir.path().join("shaders/kernels.metal").exists(),
7106            "shaders/kernels.metal should be created"
7107        );
7108    }
7109
7110    #[test]
7111    fn generated_cargo_toml_has_metal_dep() {
7112        let toml = generate_cargo_toml("my-model");
7113        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7114        assert!(
7115            toml.contains("tokenizers"),
7116            "Cargo.toml should depend on tokenizers"
7117        );
7118        assert!(
7119            toml.contains("memmap2"),
7120            "Cargo.toml should depend on memmap2"
7121        );
7122        assert!(toml.contains("half"), "Cargo.toml should depend on half");
7123    }
7124
7125    #[test]
7126    fn generated_model_rs_contains_metal_code() {
7127        let config = minimal_config();
7128        let model_rs = generate_model_rs(&config).unwrap();
7129
7130        assert!(
7131            model_rs.contains("pub struct MetalModel"),
7132            "model.rs should define MetalModel struct"
7133        );
7134        assert!(
7135            model_rs.contains("matmul_pipeline: ComputePipelineState"),
7136            "MetalModel should have matmul_pipeline field"
7137        );
7138        assert!(
7139            model_rs.contains("Device::system_default()"),
7140            "model.rs should use Metal device"
7141        );
7142        assert!(
7143            model_rs.contains("new_library_with_source"),
7144            "model.rs should compile Metal shaders"
7145        );
7146        assert!(
7147            model_rs.contains("fn new(weights: &[u8])"),
7148            "MetalModel should implement new()"
7149        );
7150        assert!(
7151            model_rs.contains("fn forward(&mut self, token_id: u32)"),
7152            "MetalModel should implement forward()"
7153        );
7154    }
7155
7156    #[test]
7157    fn generated_shaders_contain_kernel_names() {
7158        let shaders = generate_metal_shaders(&minimal_config());
7159
7160        assert!(
7161            shaders.contains("kernel void matmul_vec"),
7162            "shaders should contain matmul_vec kernel"
7163        );
7164        assert!(
7165            shaders.contains("kernel void rms_norm"),
7166            "shaders should contain rms_norm kernel"
7167        );
7168        assert!(
7169            shaders.contains("kernel void rope"),
7170            "shaders should contain rope kernel"
7171        );
7172        assert!(
7173            shaders.contains("kernel void softmax"),
7174            "shaders should contain softmax kernel"
7175        );
7176        assert!(
7177            shaders.contains("kernel void silu_mul("),
7178            "shaders should contain silu_mul kernel"
7179        );
7180        assert!(
7181            shaders.contains("kernel void silu_mul_fused"),
7182            "shaders should contain silu_mul_fused kernel"
7183        );
7184        assert!(
7185            shaders.contains("kernel void elementwise_add"),
7186            "shaders should contain elementwise_add kernel"
7187        );
7188        assert!(
7189            shaders.contains("kernel void attention"),
7190            "shaders should contain attention kernel"
7191        );
7192        assert!(
7193            shaders.contains("kernel void add_inplace"),
7194            "shaders should contain add_inplace kernel"
7195        );
7196        assert!(
7197            shaders.contains("kernel void copy_buffer"),
7198            "shaders should contain copy_buffer kernel"
7199        );
7200        assert!(
7201            shaders.contains("kernel void copy_offset"),
7202            "shaders should contain copy_offset kernel"
7203        );
7204    }
7205
7206    #[test]
7207    fn generated_shaders_use_simdgroup_features() {
7208        let shaders = generate_metal_shaders(&minimal_config());
7209
7210        assert!(
7211            shaders.contains("threadgroup_barrier"),
7212            "shaders should use threadgroup barriers"
7213        );
7214        assert!(
7215            shaders.contains("threadgroup float"),
7216            "shaders should use threadgroup shared memory"
7217        );
7218        assert!(
7219            shaders.contains("thread_index_in_threadgroup"),
7220            "shaders should use threadgroup indexing"
7221        );
7222        assert!(
7223            shaders.contains("simd_sum"),
7224            "shaders should use simd_sum for warp-level reduction"
7225        );
7226        assert!(
7227            shaders.contains("simd_max"),
7228            "attention kernel should use simd_max for cooperative softmax"
7229        );
7230        assert!(
7231            shaders.contains("thread_index_in_simdgroup"),
7232            "shaders should use simdgroup lane indexing"
7233        );
7234        assert!(
7235            shaders.contains("simdgroup_index_in_threadgroup"),
7236            "shaders should use simdgroup indexing within threadgroup"
7237        );
7238        assert!(
7239            shaders.contains("float4"),
7240            "matmul_vec should use float4 vectorized loads"
7241        );
7242    }
7243
7244    #[test]
7245    fn generated_main_rs_has_tokenizer_usage() {
7246        let config = minimal_config();
7247        let main_rs = generate_main_rs("test-model", &config).unwrap();
7248
7249        assert!(
7250            main_rs.contains("tokenizers::Tokenizer"),
7251            "main.rs should use tokenizers crate"
7252        );
7253        assert!(
7254            main_rs.contains("MetalModel::new"),
7255            "main.rs should call MetalModel::new"
7256        );
7257        assert!(
7258            main_rs.contains("model.forward"),
7259            "main.rs should call model.forward"
7260        );
7261        assert!(
7262            main_rs.contains("memmap2"),
7263            "main.rs should use memmap2 for zero-copy weight loading"
7264        );
7265    }
7266
7267    #[test]
7268    fn missing_config_returns_error() {
7269        let dir = tempfile::tempdir().unwrap();
7270        let graph = Graph::new("no-config");
7271        let result = generate_metal_project(&graph, dir.path(), "fail");
7272        assert!(
7273            matches!(result, Err(MetalCodegenError::MissingConfig)),
7274            "should fail with MissingConfig when graph has no config"
7275        );
7276    }
7277
7278    #[test]
7279    fn sanitize_name_works() {
7280        assert_eq!(sanitize_name("My Model!"), "my-model");
7281        assert_eq!(sanitize_name("test_model"), "test-model");
7282        assert_eq!(sanitize_name("simple"), "simple");
7283    }
7284
7285    #[test]
7286    fn generated_forward_uses_single_command_buffer() {
7287        let config = minimal_config();
7288        let model_rs = generate_model_rs(&config).unwrap();
7289
7290        // The forward function should create exactly one command buffer.
7291        // Use the exact signature to avoid matching forward_prefill/forward_profile.
7292        let forward_start = model_rs
7293            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7294            .unwrap();
7295        let forward_body = &model_rs[forward_start..];
7296        // End at the next pub/private method
7297        let forward_end = forward_body
7298            .find("\n    pub fn forward_profile")
7299            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7300            .or_else(|| forward_body.find("\n    fn dispatch_"))
7301            .unwrap_or(forward_body.len());
7302        let forward_code = &forward_body[..forward_end];
7303
7304        // Should have exactly one new_command_buffer call
7305        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7306        assert_eq!(
7307            cmd_buf_count, 1,
7308            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7309        );
7310
7311        // Should have exactly one commit call
7312        let commit_count = forward_code.matches("cmd.commit()").count();
7313        assert_eq!(
7314            commit_count, 1,
7315            "forward() should commit exactly once, found {commit_count}"
7316        );
7317
7318        // Should wait: once for cmd + possibly once for prev_cmd drain
7319        let wait_count = forward_code.matches("wait_until_completed()").count();
7320        assert!(
7321            wait_count >= 1 && wait_count <= 2,
7322            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7323        );
7324    }
7325
7326    #[test]
7327    fn generated_model_has_preallocated_working_buffers() {
7328        let config = minimal_config();
7329        let model_rs = generate_model_rs(&config).unwrap();
7330
7331        for buf_name in &[
7332            "normed_buf",
7333            "qkv_buf",
7334            "attn_out_buf",
7335            "attn_proj_buf",
7336            "gate_up_buf",
7337            "ffn_hidden_buf",
7338            "ffn_out_buf",
7339            "add_tmp_buf",
7340        ] {
7341            assert!(
7342                model_rs.contains(&format!("{buf_name}: Buffer")),
7343                "MetalModel should have pre-allocated {buf_name} field"
7344            );
7345        }
7346    }
7347
7348    #[test]
7349    fn generated_dispatch_helpers_take_compute_encoder_ref() {
7350        let config = minimal_config();
7351        let model_rs = generate_model_rs(&config).unwrap();
7352
7353        for method in &[
7354            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7355            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7356            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7357            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7358            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7359            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7360            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7361            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7362            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7363            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7364            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7365            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7366            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7367        ] {
7368            assert!(
7369                model_rs.contains(method),
7370                "model.rs should contain dispatch helper: {method}"
7371            );
7372        }
7373    }
7374
7375    #[test]
7376    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7377        let config = minimal_config();
7378        let model_rs = generate_model_rs(&config).unwrap();
7379
7380        // Find dispatch helpers section and check none create their own encoders
7381        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7382        let helpers_code = &model_rs[helpers_start..];
7383
7384        // None of the dispatch_ helpers should call new_command_buffer
7385        assert!(
7386            !helpers_code.contains("self.queue.new_command_buffer()"),
7387            "dispatch helpers should not create their own command buffers"
7388        );
7389
7390        // None should create their own compute encoders
7391        assert!(
7392            !helpers_code.contains("new_compute_command_encoder()"),
7393            "dispatch helpers should not create their own compute encoders"
7394        );
7395
7396        // None should call end_encoding
7397        assert!(
7398            !helpers_code.contains("end_encoding()"),
7399            "dispatch helpers should not call end_encoding"
7400        );
7401
7402        // None should call commit or wait
7403        assert!(
7404            !helpers_code.contains(".commit()"),
7405            "dispatch helpers should not commit command buffers"
7406        );
7407        assert!(
7408            !helpers_code.contains("wait_until_completed"),
7409            "dispatch helpers should not wait on command buffers"
7410        );
7411    }
7412
7413    #[test]
7414    fn generated_forward_batches_compute_encoders() {
7415        let config = minimal_config();
7416        let model_rs = generate_model_rs(&config).unwrap();
7417
7418        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7419        let forward_start = model_rs
7420            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7421            .unwrap();
7422        let forward_body = &model_rs[forward_start..];
7423        let forward_end = forward_body
7424            .find("\n    pub fn forward_profile")
7425            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7426            .or_else(|| forward_body.find("\n    fn dispatch_"))
7427            .unwrap_or(forward_body.len());
7428        let forward_code = &forward_body[..forward_end];
7429
7430        // Forward should not allocate new buffers
7431        assert!(
7432            !forward_code.contains("device.new_buffer"),
7433            "forward() should not allocate new buffers per call"
7434        );
7435
7436        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7437        // Copy operations use compute copy kernels instead of blit encoders.
7438        let compute_encoder_count = forward_code
7439            .matches("new_compute_command_encoder()")
7440            .count();
7441        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7442
7443        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7444        assert_eq!(
7445            compute_encoder_count, 1,
7446            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7447        );
7448        assert_eq!(
7449            blit_encoder_count, 0,
7450            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7451        );
7452    }
7453
7454    #[test]
7455    fn generated_forward_uses_add_inplace() {
7456        let config = minimal_config();
7457        let model_rs = generate_model_rs(&config).unwrap();
7458
7459        // Should use in-place add (no blit copy-back needed)
7460        assert!(
7461            model_rs.contains("dispatch_add_inplace"),
7462            "forward() should use dispatch_add_inplace for residual connections"
7463        );
7464        assert!(
7465            model_rs.contains("add_inplace_pipeline"),
7466            "MetalModel should have add_inplace_pipeline"
7467        );
7468    }
7469
7470    fn minimal_q8_config() -> ModelConfig {
7471        ModelConfig {
7472            architecture: Architecture::Llama,
7473            hidden_size: 64,
7474            intermediate_size: 128,
7475            num_layers: 2,
7476            num_attention_heads: 4,
7477            num_kv_heads: 4,
7478            head_dim: 16,
7479            vocab_size: 256,
7480            max_seq_len: 512,
7481            rms_norm_eps: 1e-5,
7482            rope_theta: 10000.0,
7483            dtype: DType::Q8_0,
7484            sliding_window_size: None,
7485            qkv_bias: false,
7486        }
7487    }
7488
7489    #[test]
7490    fn generated_shaders_contain_q8_kernel() {
7491        let shaders = generate_metal_shaders(&minimal_config());
7492
7493        assert!(
7494            shaders.contains("kernel void matmul_vec_q8"),
7495            "shaders should contain matmul_vec_q8 kernel"
7496        );
7497        assert!(
7498            shaders.contains("device const uchar* matrix"),
7499            "matmul_vec_q8 should accept raw Q8_0 bytes"
7500        );
7501        assert!(
7502            shaders.contains("packed_short4"),
7503            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7504        );
7505        assert!(
7506            shaders.contains("as_type<char2>"),
7507            "matmul_vec_q8 should bitcast short lanes to char2"
7508        );
7509        assert!(
7510            shaders.contains("device const half*"),
7511            "matmul_vec_q8 should read f16 scale via half pointer"
7512        );
7513    }
7514
7515    #[test]
7516    fn generated_model_uses_fused_qkv_projections() {
7517        let config = minimal_config();
7518        let model_rs = generate_model_rs(&config).unwrap();
7519
7520        // Should have fused QKV weight in layer buffers
7521        assert!(
7522            model_rs.contains("qkv_weight: Buffer"),
7523            "LayerBuffers should have fused qkv_weight field"
7524        );
7525        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7526        assert!(
7527            !model_rs.contains("    q_weight: Buffer"),
7528            "LayerBuffers should not have separate q_weight field"
7529        );
7530        assert!(
7531            !model_rs.contains("    k_weight: Buffer"),
7532            "LayerBuffers should not have separate k_weight field"
7533        );
7534        assert!(
7535            !model_rs.contains("    v_weight: Buffer"),
7536            "LayerBuffers should not have separate v_weight field"
7537        );
7538
7539        // Should have fused gate_up_weight
7540        assert!(
7541            model_rs.contains("gate_up_weight: Buffer"),
7542            "LayerBuffers should have fused gate_up_weight field"
7543        );
7544        // Should NOT have separate gate/up weight fields
7545        assert!(
7546            !model_rs.contains("    gate_weight: Buffer"),
7547            "LayerBuffers should not have separate gate_weight field"
7548        );
7549        assert!(
7550            !model_rs.contains("    up_weight: Buffer"),
7551            "LayerBuffers should not have separate up_weight field"
7552        );
7553
7554        // Should have fused working buffers
7555        assert!(
7556            model_rs.contains("qkv_buf: Buffer"),
7557            "MetalModel should have fused qkv_buf"
7558        );
7559        assert!(
7560            model_rs.contains("gate_up_buf: Buffer"),
7561            "MetalModel should have fused gate_up_buf"
7562        );
7563
7564        // Forward pass should use fused dispatch
7565        assert!(
7566            model_rs.contains("dispatch_silu_mul_fused"),
7567            "forward pass should use dispatch_silu_mul_fused"
7568        );
7569        assert!(
7570            model_rs.contains("dispatch_rope_offset"),
7571            "forward pass should use dispatch_rope_offset for fused QKV"
7572        );
7573        assert!(
7574            model_rs.contains("dispatch_attention_offset"),
7575            "forward pass should use dispatch_attention_offset for fused QKV"
7576        );
7577    }
7578
7579    #[test]
7580    fn q8_model_has_matmul_q8_pipeline() {
7581        let config = minimal_q8_config();
7582        let model_rs = generate_model_rs(&config).unwrap();
7583
7584        assert!(
7585            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7586            "MetalModel should have matmul_q8_pipeline field"
7587        );
7588        assert!(
7589            model_rs.contains("matmul_q8_pipeline,"),
7590            "MetalModel Self should include matmul_q8_pipeline"
7591        );
7592    }
7593
7594    #[test]
7595    fn q8_model_uses_dispatch_matmul_q8() {
7596        let config = minimal_q8_config();
7597        let model_rs = generate_model_rs(&config).unwrap();
7598
7599        assert!(
7600            model_rs.contains("dispatch_matmul_q8"),
7601            "Q8_0 model should use dispatch_matmul_q8 for projections"
7602        );
7603        assert!(
7604            model_rs.contains("fn dispatch_matmul_q8"),
7605            "model.rs should define dispatch_matmul_q8 method"
7606        );
7607    }
7608
7609    #[test]
7610    fn q8_model_loads_raw_bytes_not_dequantized() {
7611        let config = minimal_q8_config();
7612        let model_rs = generate_model_rs(&config).unwrap();
7613
7614        // Should NOT contain dequantization code
7615        assert!(
7616            !model_rs.contains("f16_to_f32"),
7617            "Q8_0 model should not dequantize weights to f32"
7618        );
7619        assert!(
7620            !model_rs.contains("f32_data"),
7621            "Q8_0 model should not create f32 weight data"
7622        );
7623
7624        // Should load raw Q8_0 bytes directly
7625        assert!(
7626            model_rs.contains("total_raw as u64"),
7627            "Q8_0 model should load raw bytes into Metal buffer"
7628        );
7629    }
7630
7631    #[test]
7632    fn q8_model_norms_stay_f32() {
7633        let config = minimal_q8_config();
7634        let model_rs = generate_model_rs(&config).unwrap();
7635
7636        // Norm weights should still use f32 buffers
7637        assert!(
7638            model_rs.contains("let attn_norm = next_f32_buffer"),
7639            "attn_norm should use f32 buffer even for Q8_0 models"
7640        );
7641        assert!(
7642            model_rs.contains("let ffn_norm = next_f32_buffer"),
7643            "ffn_norm should use f32 buffer even for Q8_0 models"
7644        );
7645        assert!(
7646            model_rs.contains("let norm_buf = next_f32_buffer"),
7647            "final norm should use f32 buffer even for Q8_0 models"
7648        );
7649    }
7650
7651    #[test]
7652    fn q8_model_uses_fused_weight_loading() {
7653        let config = minimal_q8_config();
7654        let model_rs = generate_model_rs(&config).unwrap();
7655
7656        // Should use fused Q8 buffer loading for QKV
7657        assert!(
7658            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7659            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7660        );
7661        // Should use fused Q8 buffer loading for gate+up
7662        assert!(
7663            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7664            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7665        );
7666        // Should still use regular q8 buffer for individual weights
7667        assert!(
7668            model_rs.contains("let o_weight = next_q8_buffer"),
7669            "Q8_0 model should use next_q8_buffer for O weight"
7670        );
7671        assert!(
7672            model_rs.contains("let down_weight = next_q8_buffer"),
7673            "Q8_0 model should use next_q8_buffer for down weight"
7674        );
7675    }
7676
7677    #[test]
7678    fn f32_model_does_not_use_q8_dispatch() {
7679        let config = minimal_config();
7680        let model_rs = generate_model_rs(&config).unwrap();
7681
7682        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7683        let forward_start = model_rs
7684            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7685            .unwrap();
7686        let forward_body = &model_rs[forward_start..];
7687        let forward_end = forward_body
7688            .find("\n    fn dispatch_")
7689            .unwrap_or(forward_body.len());
7690        let forward_code = &forward_body[..forward_end];
7691
7692        assert!(
7693            !forward_code.contains("dispatch_matmul_q8"),
7694            "f32 model forward should not use dispatch_matmul_q8"
7695        );
7696    }
7697
7698    #[test]
7699    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7700        let config = minimal_q8_config();
7701        let model_rs = generate_model_rs(&config).unwrap();
7702
7703        assert!(
7704            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7705            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7706        );
7707    }
7708
7709    #[test]
7710    fn generated_model_has_double_buffered_prefill() {
7711        let config = minimal_config();
7712        let model_rs = generate_model_rs(&config).unwrap();
7713
7714        // MetalModel should have prev_cmd field for double-buffered prefill
7715        assert!(
7716            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7717            "MetalModel should have prev_cmd field for double-buffered prefill"
7718        );
7719
7720        // Should have forward_prefill method
7721        assert!(
7722            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7723            "MetalModel should have forward_prefill method"
7724        );
7725
7726        // forward() should drain prev_cmd at the start
7727        assert!(
7728            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7729            "forward() should drain prev_cmd from previous prefill"
7730        );
7731    }
7732
7733    #[test]
7734    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7735        let config = minimal_config();
7736        let main_rs = generate_main_rs("test-model", &config).unwrap();
7737
7738        assert!(
7739            main_rs.contains("forward_prefill"),
7740            "main.rs should use forward_prefill for intermediate prompt tokens"
7741        );
7742        assert!(
7743            main_rs.contains("double-buffered"),
7744            "main.rs should document double-buffered prefill"
7745        );
7746    }
7747
7748    #[test]
7749    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7750        let shaders = generate_metal_shaders(&minimal_config());
7751
7752        assert!(
7753            shaders.contains("packed_short4"),
7754            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7755        );
7756        assert!(
7757            shaders.contains("d0[0]"),
7758            "matmul_vec_q8 should index the wide pointer for row 0"
7759        );
7760        assert!(
7761            shaders.contains("as_type<char2>"),
7762            "matmul_vec_q8 should bitcast short lanes to char2"
7763        );
7764        assert!(
7765            shaders.contains("dot("),
7766            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7767        );
7768    }
7769
7770    // ── Q4_0 tests ──────────────────────────────────────────────────────
7771
7772    fn minimal_q4_config() -> ModelConfig {
7773        ModelConfig {
7774            architecture: Architecture::Llama,
7775            hidden_size: 64,
7776            intermediate_size: 128,
7777            num_layers: 2,
7778            num_attention_heads: 4,
7779            num_kv_heads: 4,
7780            head_dim: 16,
7781            vocab_size: 256,
7782            max_seq_len: 512,
7783            rms_norm_eps: 1e-5,
7784            rope_theta: 10000.0,
7785            dtype: DType::Q4_0,
7786            sliding_window_size: None,
7787            qkv_bias: false,
7788        }
7789    }
7790
7791    #[test]
7792    fn generated_shaders_contain_q4_kernel() {
7793        let shaders = generate_metal_shaders(&minimal_config());
7794
7795        assert!(
7796            shaders.contains("kernel void matmul_vec_q4"),
7797            "shaders should contain matmul_vec_q4 kernel"
7798        );
7799        assert!(
7800            shaders.contains("Q4_ROWS_PER_TG"),
7801            "shaders should define Q4_ROWS_PER_TG constant"
7802        );
7803        assert!(
7804            shaders.contains("Q4_ROWS_PER_SG"),
7805            "shaders should define Q4_ROWS_PER_SG constant"
7806        );
7807    }
7808
7809    #[test]
7810    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7811        let shaders = generate_metal_shaders(&minimal_config());
7812
7813        // Q4_0 kernel should use uchar4 for packed byte loads
7814        assert!(
7815            shaders.contains("uchar4"),
7816            "matmul_vec_q4 should use uchar4 for packed byte loads"
7817        );
7818        // Should unpack nibbles with &0xF and >>4
7819        assert!(
7820            shaders.contains("&0xF"),
7821            "matmul_vec_q4 should extract low nibble with &0xF"
7822        );
7823        assert!(
7824            shaders.contains(">>4"),
7825            "matmul_vec_q4 should extract high nibble with >>4"
7826        );
7827        // Should subtract 8 to convert unsigned to signed
7828        assert!(
7829            shaders.contains("-8)"),
7830            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7831        );
7832        // Should use 18-byte block size
7833        assert!(
7834            shaders.contains("blk * 18"),
7835            "matmul_vec_q4 should use 18-byte block stride"
7836        );
7837    }
7838
7839    #[test]
7840    fn q4_model_has_matmul_q4_pipeline() {
7841        let config = minimal_q4_config();
7842        let model_rs = generate_model_rs(&config).unwrap();
7843
7844        assert!(
7845            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7846            "MetalModel should have matmul_q4_pipeline field"
7847        );
7848        assert!(
7849            model_rs.contains("matmul_q4_pipeline,"),
7850            "MetalModel Self should include matmul_q4_pipeline"
7851        );
7852    }
7853
7854    #[test]
7855    fn q4_model_uses_dispatch_matmul_q4() {
7856        let config = minimal_q4_config();
7857        let model_rs = generate_model_rs(&config).unwrap();
7858
7859        assert!(
7860            model_rs.contains("dispatch_matmul_q4"),
7861            "Q4_0 model should use dispatch_matmul_q4 for projections"
7862        );
7863        assert!(
7864            model_rs.contains("fn dispatch_matmul_q4"),
7865            "model.rs should define dispatch_matmul_q4 method"
7866        );
7867    }
7868
7869    #[test]
7870    fn q4_model_loads_raw_bytes_not_dequantized() {
7871        let config = minimal_q4_config();
7872        let model_rs = generate_model_rs(&config).unwrap();
7873
7874        // Should NOT contain dequantization code
7875        assert!(
7876            !model_rs.contains("f16_to_f32"),
7877            "Q4_0 model should not dequantize weights to f32"
7878        );
7879
7880        // Should load raw Q4_0 bytes directly
7881        assert!(
7882            model_rs.contains("total_raw as u64"),
7883            "Q4_0 model should load raw bytes into Metal buffer"
7884        );
7885    }
7886
7887    #[test]
7888    fn q4_model_norms_stay_f32() {
7889        let config = minimal_q4_config();
7890        let model_rs = generate_model_rs(&config).unwrap();
7891
7892        assert!(
7893            model_rs.contains("let attn_norm = next_f32_buffer"),
7894            "attn_norm should use f32 buffer even for Q4_0 models"
7895        );
7896        assert!(
7897            model_rs.contains("let ffn_norm = next_f32_buffer"),
7898            "ffn_norm should use f32 buffer even for Q4_0 models"
7899        );
7900        assert!(
7901            model_rs.contains("let norm_buf = next_f32_buffer"),
7902            "final norm should use f32 buffer even for Q4_0 models"
7903        );
7904    }
7905
7906    #[test]
7907    fn q4_model_uses_fused_weight_loading() {
7908        let config = minimal_q4_config();
7909        let model_rs = generate_model_rs(&config).unwrap();
7910
7911        assert!(
7912            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7913            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7914        );
7915        assert!(
7916            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7917            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7918        );
7919        assert!(
7920            model_rs.contains("let o_weight = next_q4_buffer"),
7921            "Q4_0 model should use next_q4_buffer for O weight"
7922        );
7923        assert!(
7924            model_rs.contains("let down_weight = next_q4_buffer"),
7925            "Q4_0 model should use next_q4_buffer for down weight"
7926        );
7927    }
7928
7929    #[test]
7930    fn attention_flash_batch_kernel_exists() {
7931        // The flash kernel is still wired into the library (pipeline, kernel
7932        // source).  Dispatch is currently routed to the legacy path pending a
7933        // fix for a numerical issue discovered after the prompt-chunking fix.
7934        let config = minimal_config();
7935        let model_rs = generate_model_rs(&config).unwrap();
7936        let shaders = generate_metal_shaders(&config);
7937
7938        assert!(
7939            shaders.contains("kernel void attention_flash_batch"),
7940            "shaders.metal must still contain the attention_flash_batch kernel"
7941        );
7942        assert!(
7943            shaders.contains("FLASH_K_TILE"),
7944            "flash kernel must tile K/V with a FLASH_K_TILE constant"
7945        );
7946        assert!(
7947            model_rs.contains("attention_flash_batch_pipeline"),
7948            "MetalModel must register the flash attention pipeline"
7949        );
7950    }
7951
7952    #[test]
7953    fn decode_uses_fused_rope_and_kv_copy() {
7954        // v0.7.2: single-token decode reuses `rope_qk_batch` (M=1) and
7955        // `copy_kv_both_batch` (M=1) instead of calling the separate
7956        // rope_offset + copy_from_offset_f16 pairs, saving 2 dispatches +
7957        // barriers per layer.
7958        let config = minimal_config();
7959        let model_rs = generate_model_rs(&config).unwrap();
7960
7961        // forward() and forward_profile() each have a decode loop.
7962        // Locate `pub fn forward(` and `pub fn forward_profile(` and verify
7963        // the fused dispatches appear between the function header and the
7964        // start of forward_prefill (which also uses these fused helpers).
7965        let forward_start = model_rs.find("pub fn forward(").expect("forward missing");
7966        let forward_end = model_rs[forward_start..]
7967            .find("pub fn forward_prefill(")
7968            .expect("forward_prefill missing");
7969        let forward_body = &model_rs[forward_start..forward_start + forward_end];
7970
7971        assert!(
7972            forward_body.contains("dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1"),
7973            "single-token forward must use fused rope_qk_batch with M=1"
7974        );
7975        assert!(
7976            forward_body.contains("dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,"),
7977            "single-token forward must use fused copy_kv_both_batch"
7978        );
7979        assert!(
7980            !forward_body.contains("dispatch_copy_from_offset_f16"),
7981            "decode path should no longer call the per-K/per-V copy"
7982        );
7983    }
7984
7985    #[test]
7986    fn kv_cache_stored_as_f16() {
7987        // v0.7.1: KV cache is f16 (2 bytes/element) instead of f32 to halve
7988        // attention memory bandwidth and KV RAM footprint.  All attention
7989        // kernels must read K/V as `device const half*`; copy kernels must
7990        // write half; the f32->f16 copy kernel must be wired for single-
7991        // token decode KV writes.
7992        let config = minimal_config();
7993        let shaders = generate_metal_shaders(&config);
7994        let model_rs = generate_model_rs(&config).unwrap();
7995
7996        for kernel in [
7997            "kernel void attention(",
7998            "kernel void attention_batch(",
7999            "kernel void attention_flash_batch(",
8000            "kernel void attention_mma_flash_batch(",
8001        ] {
8002            let start = shaders
8003                .find(kernel)
8004                .unwrap_or_else(|| panic!("kernel {kernel} missing"));
8005            // Slice through the signature block — scan for "){" which closes
8006            // the parameter list at the top of every kernel's body.
8007            let sig_end = shaders[start..]
8008                .find("){")
8009                .unwrap_or_else(|| shaders[start..].find(") {").unwrap());
8010            let sig = &shaders[start..start + sig_end];
8011            assert!(
8012                sig.contains("device const half*")
8013                    && sig.contains("k_cache")
8014                    && sig.contains("v_cache"),
8015                "{kernel} must read k_cache/v_cache as `device const half*`"
8016            );
8017            assert!(
8018                !sig.contains("device const float* k_cache")
8019                    && !sig.contains("device const float*  k_cache"),
8020                "{kernel} still reads k_cache as float"
8021            );
8022        }
8023
8024        assert!(
8025            shaders.contains("kernel void copy_f32_to_f16_offset"),
8026            "f32->f16 copy kernel must be present for single-token decode KV writes"
8027        );
8028        assert!(
8029            model_rs.contains("dispatch_copy_from_offset_f16"),
8030            "single-token decode must dispatch the f32->f16 KV copy"
8031        );
8032    }
8033
8034    #[test]
8035    fn attention_mma_flash_batch_kernel_wired() {
8036        // MMA-accelerated flash attention (issue #212).  Default-on in v0.7.0
8037        // when HEAD_DIM ≤ 128 and num_tokens ≥ 8.  FORGE_MMA_ATTN=0 opts out.
8038        let config = minimal_config();
8039        let model_rs = generate_model_rs(&config).unwrap();
8040        let shaders = generate_metal_shaders(&config);
8041
8042        assert!(
8043            shaders.contains("kernel void attention_mma_flash_batch"),
8044            "shaders.metal must contain the MMA flash kernel"
8045        );
8046        assert!(
8047            shaders.contains("FLASH_MMA_Q_BLOCK"),
8048            "MMA flash kernel must define Q_BLOCK tiling constant"
8049        );
8050        assert!(
8051            shaders.contains("simdgroup_multiply_accumulate"),
8052            "MMA flash kernel must use hardware MMA"
8053        );
8054        assert!(
8055            model_rs.contains("attention_mma_flash_batch_pipeline"),
8056            "MetalModel must register the MMA flash pipeline"
8057        );
8058        assert!(
8059            model_rs.contains("mma_opt_out"),
8060            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out"
8061        );
8062        assert!(
8063            model_rs.contains("!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8"),
8064            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8"
8065        );
8066    }
8067
8068    #[test]
8069    fn forward_prefill_batch_chunks_by_max_batch_size() {
8070        // Regression: prior to v0.6.4 forward_prefill_batch truncated prompts
8071        // longer than MAX_BATCH_SIZE (512) tokens via `.min(MAX_BATCH_SIZE)`,
8072        // silently dropping the middle of long prompts.  Must now loop over
8073        // MAX_BATCH_SIZE-sized chunks and carry KV-cache state across them.
8074        let config = minimal_config();
8075        let model_rs = generate_model_rs(&config).unwrap();
8076        assert!(
8077            model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
8078            "forward_prefill_batch must chunk long prompts"
8079        );
8080        assert!(
8081            !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
8082            "the old truncation path must be gone"
8083        );
8084    }
8085
8086    #[test]
8087    fn qwen2_qkv_bias_wired_through_metal_codegen() {
8088        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
8089        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
8090        // loader, pipeline, and dispatch call in the expected places.
8091        let config = ModelConfig {
8092            architecture: Architecture::Qwen2,
8093            qkv_bias: true,
8094            ..minimal_config()
8095        };
8096        let model_rs = generate_model_rs(&config).unwrap();
8097
8098        assert!(
8099            model_rs.contains("qkv_bias: Buffer"),
8100            "Qwen2 LayerBuffers must declare qkv_bias field"
8101        );
8102        assert!(
8103            model_rs.contains("let qkv_bias = next_f32_buffer"),
8104            "Qwen2 layer init must load the bias from the weight blob"
8105        );
8106        assert!(
8107            model_rs.contains("add_bias_batch_pipeline"),
8108            "Qwen2 model struct must include the add_bias_batch_pipeline"
8109        );
8110        assert!(
8111            model_rs.contains("fn dispatch_add_bias_batch"),
8112            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
8113        );
8114        assert!(
8115            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
8116            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
8117        );
8118        assert!(
8119            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
8120            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
8121        );
8122
8123        // The add_bias_batch MSL kernel must be in the shader source.
8124        let shaders = generate_metal_shaders(&config);
8125        assert!(
8126            shaders.contains("kernel void add_bias_batch"),
8127            "shaders.metal must contain the add_bias_batch kernel"
8128        );
8129    }
8130
8131    #[test]
8132    fn llama_does_not_emit_qkv_bias_machinery() {
8133        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
8134        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
8135        let config = minimal_config();
8136        assert!(!config.qkv_bias);
8137        let model_rs = generate_model_rs(&config).unwrap();
8138        assert!(
8139            !model_rs.contains("qkv_bias: Buffer"),
8140            "Llama must not have qkv_bias field"
8141        );
8142        assert!(
8143            !model_rs.contains("add_bias_batch_pipeline"),
8144            "Llama must not pull in add_bias_batch_pipeline"
8145        );
8146        assert!(
8147            !model_rs.contains("dispatch_add_bias_batch"),
8148            "Llama must not call dispatch_add_bias_batch"
8149        );
8150    }
8151
8152    #[test]
8153    fn q4_dispatch_helper_takes_compute_encoder_ref() {
8154        let config = minimal_q4_config();
8155        let model_rs = generate_model_rs(&config).unwrap();
8156
8157        assert!(
8158            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
8159            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
8160        );
8161    }
8162
8163    #[test]
8164    fn f32_model_does_not_use_q4_dispatch() {
8165        let config = minimal_config();
8166        let model_rs = generate_model_rs(&config).unwrap();
8167
8168        let forward_start = model_rs
8169            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8170            .unwrap();
8171        let forward_body = &model_rs[forward_start..];
8172        let forward_end = forward_body
8173            .find("\n    fn dispatch_")
8174            .unwrap_or(forward_body.len());
8175        let forward_code = &forward_body[..forward_end];
8176
8177        assert!(
8178            !forward_code.contains("dispatch_matmul_q4"),
8179            "f32 model forward should not use dispatch_matmul_q4"
8180        );
8181    }
8182
8183    #[test]
8184    fn q4_model_lm_head_uses_q4_buffer() {
8185        let config = minimal_q4_config();
8186        let model_rs = generate_model_rs(&config).unwrap();
8187
8188        assert!(
8189            model_rs.contains("let lm_head_buf = next_q4_buffer"),
8190            "Q4_0 model should use next_q4_buffer for lm_head"
8191        );
8192    }
8193
8194    #[test]
8195    fn vec_tile_size_matches_model_dimensions() {
8196        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
8197        let small = minimal_config();
8198        let shaders_small = generate_metal_shaders(&small);
8199        assert!(
8200            shaders_small.contains("vec_tile[128]"),
8201            "vec_tile should be sized to max(hidden, intermediate) = 128"
8202        );
8203
8204        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
8205        let mut large = minimal_config();
8206        large.hidden_size = 2048;
8207        large.intermediate_size = 8192;
8208        let shaders_large = generate_metal_shaders(&large);
8209        assert!(
8210            shaders_large.contains("vec_tile[8192]"),
8211            "vec_tile should be 8192 for models with intermediate=8192"
8212        );
8213        assert!(
8214            !shaders_large.contains("vec_tile[4096]"),
8215            "vec_tile should NOT be hardcoded to 4096"
8216        );
8217    }
8218
8219    #[test]
8220    fn generated_cargo_toml_has_server_deps() {
8221        let toml = generate_cargo_toml("my-model");
8222        assert!(
8223            toml.contains("tiny_http"),
8224            "Cargo.toml should depend on tiny_http"
8225        );
8226        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
8227        assert!(
8228            toml.contains("serde_json"),
8229            "Cargo.toml should depend on serde_json"
8230        );
8231    }
8232
8233    #[test]
8234    fn generated_main_rs_has_serve_mode() {
8235        let config = minimal_config();
8236        let main_rs = generate_main_rs("test-model", &config).unwrap();
8237
8238        assert!(
8239            main_rs.contains("--serve"),
8240            "main.rs should parse --serve flag"
8241        );
8242        assert!(
8243            main_rs.contains("--port"),
8244            "main.rs should parse --port flag"
8245        );
8246        assert!(
8247            main_rs.contains("fn serve("),
8248            "main.rs should define serve function"
8249        );
8250        assert!(
8251            main_rs.contains("tiny_http::Server::http"),
8252            "main.rs should create tiny_http server"
8253        );
8254    }
8255
8256    #[test]
8257    fn generated_main_rs_has_chat_completions_endpoint() {
8258        let config = minimal_config();
8259        let main_rs = generate_main_rs("test-model", &config).unwrap();
8260
8261        assert!(
8262            main_rs.contains("/v1/chat/completions"),
8263            "main.rs should handle /v1/chat/completions endpoint"
8264        );
8265        assert!(
8266            main_rs.contains("/v1/models"),
8267            "main.rs should handle /v1/models endpoint"
8268        );
8269        assert!(
8270            main_rs.contains("/health"),
8271            "main.rs should handle /health endpoint"
8272        );
8273    }
8274
8275    #[test]
8276    fn generated_main_rs_has_sse_streaming() {
8277        let config = minimal_config();
8278        let main_rs = generate_main_rs("test-model", &config).unwrap();
8279
8280        assert!(
8281            main_rs.contains("text/event-stream"),
8282            "main.rs should set SSE content type for streaming"
8283        );
8284        assert!(
8285            main_rs.contains("chat.completion.chunk"),
8286            "main.rs should emit SSE chunks"
8287        );
8288        assert!(
8289            main_rs.contains("[DONE]"),
8290            "main.rs should emit [DONE] sentinel"
8291        );
8292    }
8293
8294    #[test]
8295    fn generated_main_rs_has_chat_message_formatting() {
8296        let config = minimal_config();
8297        let main_rs = generate_main_rs("test-model", &config).unwrap();
8298
8299        assert!(
8300            main_rs.contains("fn format_chat_messages"),
8301            "main.rs should define format_chat_messages function"
8302        );
8303        assert!(
8304            main_rs.contains("<|im_start|>"),
8305            "main.rs should use ChatML format"
8306        );
8307        assert!(
8308            main_rs.contains("<|im_end|>"),
8309            "main.rs should use ChatML format"
8310        );
8311    }
8312
8313    #[test]
8314    fn generated_main_rs_has_request_types() {
8315        let config = minimal_config();
8316        let main_rs = generate_main_rs("test-model", &config).unwrap();
8317
8318        assert!(
8319            main_rs.contains("struct ChatRequest"),
8320            "main.rs should define ChatRequest struct"
8321        );
8322        assert!(
8323            main_rs.contains("struct ChatMessage"),
8324            "main.rs should define ChatMessage struct"
8325        );
8326        assert!(
8327            main_rs.contains("Deserialize"),
8328            "main.rs should derive Deserialize for request types"
8329        );
8330    }
8331
8332    #[test]
8333    fn generated_model_has_reset_method() {
8334        let config = minimal_config();
8335        let model_rs = generate_model_rs(&config).unwrap();
8336
8337        assert!(
8338            model_rs.contains("pub fn reset(&mut self)"),
8339            "model.rs should have a reset() method for multi-request serving"
8340        );
8341        assert!(
8342            model_rs.contains("self.pos = 0"),
8343            "reset() should reset position to 0"
8344        );
8345    }
8346
8347    #[test]
8348    fn generated_main_rs_cli_mode_still_works() {
8349        let config = minimal_config();
8350        let main_rs = generate_main_rs("test-model", &config).unwrap();
8351
8352        // CLI mode should still be functional
8353        assert!(
8354            main_rs.contains("fn cli_mode("),
8355            "main.rs should define cli_mode function"
8356        );
8357        assert!(
8358            main_rs.contains("model.forward"),
8359            "main.rs should call model.forward"
8360        );
8361        assert!(
8362            main_rs.contains("model.forward_prefill"),
8363            "main.rs should call model.forward_prefill"
8364        );
8365    }
8366
8367    // ── Batched prefill tests ──────────────────────────────────────────
8368
8369    #[test]
8370    fn generated_shaders_contain_batch_kernels() {
8371        let shaders = generate_metal_shaders(&minimal_config());
8372
8373        assert!(
8374            shaders.contains("kernel void matmul_vec_batch"),
8375            "shaders should contain matmul_vec_batch kernel"
8376        );
8377        assert!(
8378            shaders.contains("kernel void matmul_vec_q8_batch"),
8379            "shaders should contain matmul_vec_q8_batch kernel"
8380        );
8381        assert!(
8382            shaders.contains("kernel void matmul_q8_gemm_batch"),
8383            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
8384        );
8385        assert!(
8386            shaders.contains("kernel void matmul_vec_q4_batch"),
8387            "shaders should contain matmul_vec_q4_batch kernel"
8388        );
8389        assert!(
8390            shaders.contains("kernel void rms_norm_batch"),
8391            "shaders should contain rms_norm_batch kernel"
8392        );
8393        assert!(
8394            shaders.contains("kernel void silu_mul_fused_batch"),
8395            "shaders should contain silu_mul_fused_batch kernel"
8396        );
8397        assert!(
8398            shaders.contains("kernel void add_inplace_batch"),
8399            "shaders should contain add_inplace_batch kernel"
8400        );
8401        assert!(
8402            shaders.contains("kernel void copy_embedding_batch"),
8403            "shaders should contain copy_embedding_batch kernel"
8404        );
8405    }
8406
8407    #[test]
8408    fn generated_model_has_batch_pipelines() {
8409        let config = minimal_config();
8410        let model_rs = generate_model_rs(&config).unwrap();
8411
8412        for pipeline in &[
8413            "matmul_batch_pipeline",
8414            "matmul_q8_batch_pipeline",
8415            "matmul_q8_gemm_batch_pipeline",
8416            "matmul_q4_batch_pipeline",
8417            "rms_norm_batch_pipeline",
8418            "rope_batch_pipeline",
8419            "silu_mul_fused_batch_pipeline",
8420            "add_inplace_batch_pipeline",
8421            "copy_embedding_batch_pipeline",
8422            "attention_batch_pipeline",
8423            "copy_kv_batch_pipeline",
8424        ] {
8425            assert!(
8426                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
8427                "MetalModel should have {pipeline} field"
8428            );
8429        }
8430    }
8431
8432    #[test]
8433    fn generated_model_has_batch_buffers() {
8434        let config = minimal_config();
8435        let model_rs = generate_model_rs(&config).unwrap();
8436
8437        for buf in &[
8438            "batch_hidden_buf",
8439            "batch_residual_buf",
8440            "batch_qkv_buf",
8441            "batch_attn_out_buf",
8442            "batch_attn_proj_buf",
8443            "batch_gate_up_buf",
8444            "batch_ffn_hidden_buf",
8445            "batch_ffn_out_buf",
8446            "batch_tokens_buf",
8447            "batch_positions_buf",
8448        ] {
8449            assert!(
8450                model_rs.contains(&format!("{buf}: Buffer")),
8451                "MetalModel should have {buf} field"
8452            );
8453        }
8454    }
8455
8456    #[test]
8457    fn generated_model_has_forward_prefill_batch() {
8458        let config = minimal_config();
8459        let model_rs = generate_model_rs(&config).unwrap();
8460
8461        assert!(
8462            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8463            "MetalModel should have forward_prefill_batch method"
8464        );
8465
8466        // forward_prefill should delegate to forward_prefill_batch
8467        assert!(
8468            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8469            "forward_prefill should delegate to forward_prefill_batch"
8470        );
8471    }
8472
8473    #[test]
8474    fn generated_model_has_max_batch_size_constant() {
8475        let config = minimal_config();
8476        let model_rs = generate_model_rs(&config).unwrap();
8477
8478        assert!(
8479            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8480            "model.rs should define MAX_BATCH_SIZE constant"
8481        );
8482    }
8483
8484    #[test]
8485    fn forward_prefill_batch_uses_batch_dispatch() {
8486        let config = minimal_config();
8487        let model_rs = generate_model_rs(&config).unwrap();
8488
8489        let batch_start = model_rs
8490            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8491            .unwrap();
8492        let batch_body = &model_rs[batch_start..];
8493        let batch_end = batch_body
8494            .find("\n    pub fn reset")
8495            .unwrap_or(batch_body.len());
8496        let batch_code = &batch_body[..batch_end];
8497
8498        // Should use batched dispatch methods
8499        assert!(
8500            batch_code.contains("dispatch_rms_norm_batch"),
8501            "forward_prefill_batch should use dispatch_rms_norm_batch"
8502        );
8503        assert!(
8504            batch_code.contains("dispatch_copy_embedding_batch"),
8505            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8506        );
8507        assert!(
8508            batch_code.contains("dispatch_silu_mul_fused_batch"),
8509            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8510        );
8511        // Should use batched causal attention dispatch
8512        assert!(
8513            batch_code.contains("dispatch_attention_batch"),
8514            "forward_prefill_batch should use dispatch_attention_batch"
8515        );
8516        // Should use fused KV cache copy (both K and V in one dispatch)
8517        assert!(
8518            batch_code.contains("dispatch_copy_kv_both_batch"),
8519            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8520        );
8521        // Should use fused RoPE Q+K dispatch
8522        assert!(
8523            batch_code.contains("dispatch_rope_qk_batch"),
8524            "forward_prefill_batch should use dispatch_rope_qk_batch"
8525        );
8526    }
8527
8528    #[test]
8529    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8530        let config = minimal_q8_config();
8531        let model_rs = generate_model_rs(&config).unwrap();
8532
8533        let batch_start = model_rs
8534            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8535            .unwrap();
8536        let batch_body = &model_rs[batch_start..];
8537        let batch_end = batch_body
8538            .find("\n    pub fn reset")
8539            .unwrap_or(batch_body.len());
8540        let batch_code = &batch_body[..batch_end];
8541
8542        assert!(
8543            batch_code.contains("dispatch_matmul_q8_batch"),
8544            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8545        );
8546    }
8547
8548    #[test]
8549    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8550        let config = minimal_q4_config();
8551        let model_rs = generate_model_rs(&config).unwrap();
8552
8553        let batch_start = model_rs
8554            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8555            .unwrap();
8556        let batch_body = &model_rs[batch_start..];
8557        let batch_end = batch_body
8558            .find("\n    pub fn reset")
8559            .unwrap_or(batch_body.len());
8560        let batch_code = &batch_body[..batch_end];
8561
8562        assert!(
8563            batch_code.contains("dispatch_matmul_q4_batch"),
8564            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8565        );
8566    }
8567
8568    #[test]
8569    fn generated_main_rs_uses_batched_prefill() {
8570        let config = minimal_config();
8571        let main_rs = generate_main_rs("test-model", &config).unwrap();
8572
8573        assert!(
8574            main_rs.contains("forward_prefill_batch"),
8575            "main.rs should use forward_prefill_batch for prompt tokens"
8576        );
8577    }
8578
8579    #[test]
8580    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8581        let config = minimal_config();
8582        let model_rs = generate_model_rs(&config).unwrap();
8583
8584        let batch_start = model_rs
8585            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8586            .unwrap();
8587        let batch_body = &model_rs[batch_start..];
8588        let batch_end = batch_body
8589            .find("\n    pub fn reset")
8590            .unwrap_or(batch_body.len());
8591        let batch_code = &batch_body[..batch_end];
8592
8593        assert!(
8594            batch_code.contains("dispatch_matmul_batch"),
8595            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8596        );
8597        // Should NOT use Q8 or Q4 batch dispatch
8598        assert!(
8599            !batch_code.contains("dispatch_matmul_q8_batch"),
8600            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8601        );
8602        assert!(
8603            !batch_code.contains("dispatch_matmul_q4_batch"),
8604            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8605        );
8606    }
8607
8608    #[test]
8609    fn forward_uses_cpu_embedding_lookup() {
8610        let config = minimal_config();
8611        let model_rs = generate_model_rs(&config).unwrap();
8612
8613        // Find just the forward() body (not forward_profile)
8614        let forward_start = model_rs
8615            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8616            .unwrap();
8617        let forward_body = &model_rs[forward_start..];
8618        let forward_end = forward_body
8619            .find("\n    pub fn forward_profile")
8620            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8621            .unwrap_or(forward_body.len());
8622        let forward_code = &forward_body[..forward_end];
8623
8624        // forward() should use CPU memcpy for embedding lookup (unified memory)
8625        assert!(
8626            forward_code.contains("embed_buf.contents()"),
8627            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8628        );
8629        assert!(
8630            forward_code.contains("copy_nonoverlapping"),
8631            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8632        );
8633        // forward() should NOT use GPU dispatch for embedding
8634        assert!(
8635            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8636            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8637        );
8638    }
8639
8640    #[test]
8641    fn forward_profile_method_exists() {
8642        let config = minimal_config();
8643        let model_rs = generate_model_rs(&config).unwrap();
8644
8645        assert!(
8646            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8647            "MetalModel should have forward_profile() method"
8648        );
8649        // Profile method should print timing information
8650        assert!(
8651            model_rs.contains("[profile]"),
8652            "forward_profile() should print timing with [profile] prefix"
8653        );
8654        assert!(
8655            model_rs.contains("d_embed"),
8656            "forward_profile() should measure embedding time"
8657        );
8658        assert!(
8659            model_rs.contains("d_layers"),
8660            "forward_profile() should measure layer time"
8661        );
8662        assert!(
8663            model_rs.contains("d_logits"),
8664            "forward_profile() should measure logits time"
8665        );
8666    }
8667
8668    #[test]
8669    fn generated_cli_has_profile_flag() {
8670        let config = minimal_config();
8671        let main_rs = generate_main_rs("test-model", &config).unwrap();
8672
8673        assert!(
8674            main_rs.contains("--profile"),
8675            "CLI should support --profile flag"
8676        );
8677        assert!(
8678            main_rs.contains("forward_profile"),
8679            "CLI should call forward_profile when --profile is set"
8680        );
8681    }
8682
8683    #[test]
8684    fn generated_cli_has_thermal_yield() {
8685        let config = minimal_config();
8686        let main_rs = generate_main_rs("test-model", &config).unwrap();
8687
8688        assert!(
8689            main_rs.contains("yield_now()"),
8690            "CLI generation loop should include thread::yield_now() for thermal management"
8691        );
8692    }
8693
8694    // ── Real-world validation tests ──────────────────────────────────────
8695
8696    #[test]
8697    fn generated_forward_handles_single_token_prompt() {
8698        // With a single token (the first prompt token), forward() should work
8699        // at pos=0 where seq_len=1. The attention kernel must handle the case
8700        // where there is only one KV entry (no prefill context).
8701        let config = minimal_config();
8702        let model_rs = generate_model_rs(&config).unwrap();
8703
8704        // The forward function should accept any u32 token_id (no minimum pos guard)
8705        let forward_start = model_rs
8706            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8707            .expect("forward() must exist");
8708        let forward_body = &model_rs[forward_start..forward_start + 400];
8709
8710        // Should NOT require pos > 0 or seq_len > 1
8711        assert!(
8712            !forward_body.contains("assert!(self.pos > 0"),
8713            "forward() must accept pos=0 (first token with no prefill)"
8714        );
8715
8716        // The attention kernel should handle seq_len=1 via the pos field
8717        assert!(
8718            model_rs.contains("self.pos"),
8719            "forward() should use self.pos to track sequence position"
8720        );
8721    }
8722
8723    #[test]
8724    fn generated_reset_clears_kv_cache_position() {
8725        // After reset(), the model should be in a clean state. The pos field
8726        // must be 0 so new generation starts from scratch.
8727        let config = minimal_config();
8728        let model_rs = generate_model_rs(&config).unwrap();
8729
8730        let reset_start = model_rs
8731            .find("pub fn reset(&mut self)")
8732            .expect("reset() must exist");
8733        let reset_body = &model_rs[reset_start..reset_start + 200];
8734
8735        // Reset must zero the position counter
8736        assert!(
8737            reset_body.contains("self.pos = 0"),
8738            "reset() must set self.pos = 0"
8739        );
8740
8741        // Verify reset clears prev_cmd (double-buffering state)
8742        assert!(
8743            reset_body.contains("self.prev_cmd = None"),
8744            "reset() should clear prev_cmd for clean command buffer state"
8745        );
8746    }
8747
8748    #[test]
8749    fn generated_serve_handles_empty_messages_gracefully() {
8750        // The serve endpoint should not crash when receiving an empty messages array.
8751        // The format_chat_messages function should handle this gracefully.
8752        let config = minimal_config();
8753        let main_rs = generate_main_rs("test-model", &config).unwrap();
8754
8755        // The format_chat_messages function should exist and handle empty input
8756        let format_fn_start = main_rs
8757            .find("fn format_chat_messages")
8758            .expect("format_chat_messages must exist");
8759        let format_fn_body =
8760            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8761
8762        // It should iterate over messages (an empty slice produces an empty loop)
8763        assert!(
8764            format_fn_body.contains("for msg in messages"),
8765            "format_chat_messages should iterate over the messages slice"
8766        );
8767        // It should always append the assistant prompt suffix
8768        assert!(
8769            format_fn_body.contains("<|im_start|>assistant"),
8770            "format_chat_messages should always append assistant prompt header"
8771        );
8772
8773        // The serve function should call model.reset() before each request
8774        let serve_fn_start = main_rs
8775            .find("fn serve(")
8776            .expect("serve function must exist");
8777        let serve_fn_body = &main_rs[serve_fn_start..];
8778        assert!(
8779            serve_fn_body.contains("model.reset()"),
8780            "serve function should reset model between requests"
8781        );
8782    }
8783
8784    #[test]
8785    fn generated_model_forward_increments_pos() {
8786        // Each forward() call must increment self.pos so the next token
8787        // uses the correct RoPE position and KV cache offset.
8788        let config = minimal_config();
8789        let model_rs = generate_model_rs(&config).unwrap();
8790
8791        let forward_start = model_rs
8792            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8793            .unwrap();
8794        let forward_body = &model_rs[forward_start..];
8795        let forward_end = forward_body
8796            .find("\n    pub fn forward_profile")
8797            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8798            .or_else(|| forward_body.find("\n    fn dispatch_"))
8799            .unwrap_or(forward_body.len());
8800        let forward_code = &forward_body[..forward_end];
8801
8802        assert!(
8803            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8804            "forward() must increment self.pos after processing a token"
8805        );
8806    }
8807}