Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a string usable as a Cargo package name.
///
/// The name is lowercased, every character that is neither alphanumeric nor
/// `-` is replaced with `-`, and leading/trailing dashes are stripped.
///
/// Cargo rejects package names that are empty or that start with a digit, so
/// two extra cases are handled:
///
/// - if nothing survives sanitization (empty or punctuation-only input), the
///   fallback name `model` is returned;
/// - if the first surviving character is not alphabetic (e.g. `01-ai/Yi-6B`),
///   the result is prefixed with `model-`.
///
/// All other inputs produce exactly the lowercased, dash-separated form.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    let lowered = name.to_lowercase();
    let dashed = lowered.replace(|c: char| !c.is_alphanumeric() && c != '-', "-");
    let trimmed = dashed.trim_matches('-');

    // After trimming, the first character (if any) is alphanumeric; a
    // non-alphabetic one is therefore numeric and would make the manifest
    // invalid.
    match trimmed.chars().next() {
        None => FALLBACK.to_string(),
        Some(first) if !first.is_alphabetic() => format!("{FALLBACK}-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    // The attention-scores shared array must be at least effective_seq_len
122    // elements; anything smaller silently overflows for long prompts.  For
123    // small-context models (135M at 2K), a smaller array saves threadgroup
124    // memory and improves occupancy — so we size it precisely.
125    let attn_scores_size = config.max_seq_len.min(4096);
126    r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154    device const float* matrix [[buffer(0)]],
155    device const float* vector [[buffer(1)]],
156    device float* output       [[buffer(2)]],
157    constant uint& rows        [[buffer(3)]],
158    constant uint& cols        [[buffer(4)]],
159    uint tgid [[threadgroup_position_in_grid]],
160    uint tid [[thread_index_in_threadgroup]],
161    uint simd_lane [[thread_index_in_simdgroup]],
162    uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164    // Cooperatively load vector into threadgroup shared memory
165    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166    for (uint i = tid; i < cols; i += 256) {
167        vec_tile[i] = vector[i];
168    }
169    threadgroup_barrier(mem_flags::mem_threadgroup);
170
171    // Each simdgroup handles 8 consecutive rows
172    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173    if (row_base >= rows) return;
174
175    uint base0 = row_base * cols;
176    uint base1 = (row_base + 1) * cols;
177    uint base2 = (row_base + 2) * cols;
178    uint base3 = (row_base + 3) * cols;
179    uint base4 = (row_base + 4) * cols;
180    uint base5 = (row_base + 5) * cols;
181    uint base6 = (row_base + 6) * cols;
182    uint base7 = (row_base + 7) * cols;
183
184    // float4 vectorized accumulation across 8 rows
185    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
186    float4 sum4_0 = float4(0.0f);
187    float4 sum4_1 = float4(0.0f);
188    float4 sum4_2 = float4(0.0f);
189    float4 sum4_3 = float4(0.0f);
190    float4 sum4_4 = float4(0.0f);
191    float4 sum4_5 = float4(0.0f);
192    float4 sum4_6 = float4(0.0f);
193    float4 sum4_7 = float4(0.0f);
194
195    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205    }
206
207    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216    // Handle remaining elements (cols not divisible by 128)
217    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218        float vv = vec_tile[j];
219        sum0 += matrix[base0 + j] * vv;
220        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227    }
228
229    // Simdgroup hardware warp-level reduction
230    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235    // Only first lane writes the results
236    if (simd_lane == 0) {
237        if (row_base     < rows) output[row_base]     = sum0;
238        if (row_base + 1 < rows) output[row_base + 1] = sum1;
239        if (row_base + 2 < rows) output[row_base + 2] = sum2;
240        if (row_base + 3 < rows) output[row_base + 3] = sum3;
241        if (row_base + 4 < rows) output[row_base + 4] = sum4;
242        if (row_base + 5 < rows) output[row_base + 5] = sum5;
243        if (row_base + 6 < rows) output[row_base + 6] = sum6;
244        if (row_base + 7 < rows) output[row_base + 7] = sum7;
245    }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253    device const float* input   [[buffer(0)]],
254    device const float* weight  [[buffer(1)]],
255    device float* output        [[buffer(2)]],
256    constant uint& n            [[buffer(3)]],
257    constant float& eps         [[buffer(4)]],
258    uint tid [[thread_index_in_threadgroup]])
259{
260    // Each thread accumulates partial sum-of-squares
261    float sum_sq = 0.0f;
262    for (uint i = tid; i < n; i += 256) {
263        float v = input[i];
264        sum_sq += v * v;
265    }
266
267    // Simdgroup-level reduction (hardware warp sum)
268    sum_sq = simd_sum(sum_sq);
269
270    // Cross-simdgroup reduction via shared memory
271    threadgroup float shared[8];
272    uint simd_id = tid / 32;
273    uint simd_lane = tid % 32;
274    if (simd_lane == 0) {
275        shared[simd_id] = sum_sq;
276    }
277    threadgroup_barrier(mem_flags::mem_threadgroup);
278
279    // First thread computes final inverse RMS
280    if (tid == 0) {
281        float total = 0.0f;
282        for (uint i = 0; i < 8; i++) {
283            total += shared[i];
284        }
285        shared[0] = fast::rsqrt(total / float(n) + eps);
286    }
287    threadgroup_barrier(mem_flags::mem_threadgroup);
288
289    float inv_rms = shared[0];
290
291    // Normalize
292    for (uint i = tid; i < n; i += 256) {
293        output[i] = input[i] * inv_rms * weight[i];
294    }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301    device float* data        [[buffer(0)]],
302    constant uint& num_heads  [[buffer(1)]],
303    constant uint& head_dim   [[buffer(2)]],
304    constant uint& pos        [[buffer(3)]],
305    constant float& theta     [[buffer(4)]],
306    uint id [[thread_position_in_grid]])
307{
308    uint half_dim = head_dim / 2;
309    uint total_pairs = num_heads * half_dim;
310    if (id >= total_pairs) return;
311
312    uint h = id / half_dim;
313    uint i = id % half_dim;
314    uint off = h * head_dim;
315
316    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317    float angle = float(pos) * freq;
318    float c = cos(angle);
319    float s = sin(angle);
320
321    float x0 = data[off + 2 * i];
322    float x1 = data[off + 2 * i + 1];
323    data[off + 2 * i]     = x0 * c - x1 * s;
324    data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331    device float* data       [[buffer(0)]],
332    constant uint& n         [[buffer(1)]],
333    uint tid [[thread_index_in_threadgroup]],
334    uint tg_size [[threads_per_threadgroup]])
335{
336    threadgroup float shared_val[256];
337
338    // Pass 1: find max
339    float local_max = -INFINITY;
340    for (uint i = tid; i < n; i += tg_size) {
341        local_max = max(local_max, data[i]);
342    }
343    shared_val[tid] = local_max;
344    threadgroup_barrier(mem_flags::mem_threadgroup);
345
346    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347        if (tid < stride) {
348            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349        }
350        threadgroup_barrier(mem_flags::mem_threadgroup);
351    }
352    float max_val = shared_val[0];
353    threadgroup_barrier(mem_flags::mem_threadgroup);
354
355    // Pass 2: exp and sum
356    float local_sum = 0.0f;
357    for (uint i = tid; i < n; i += tg_size) {
358        float e = fast::exp(data[i] - max_val);
359        data[i] = e;
360        local_sum += e;
361    }
362    shared_val[tid] = local_sum;
363    threadgroup_barrier(mem_flags::mem_threadgroup);
364
365    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366        if (tid < stride) {
367            shared_val[tid] += shared_val[tid + stride];
368        }
369        threadgroup_barrier(mem_flags::mem_threadgroup);
370    }
371    float inv_sum = 1.0f / shared_val[0];
372    threadgroup_barrier(mem_flags::mem_threadgroup);
373
374    // Pass 3: normalize
375    for (uint i = tid; i < n; i += tg_size) {
376        data[i] *= inv_sum;
377    }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384    device const float* gate [[buffer(0)]],
385    device const float* up   [[buffer(1)]],
386    device float* output     [[buffer(2)]],
387    constant uint& n         [[buffer(3)]],
388    uint id [[thread_position_in_grid]])
389{
390    if (id >= n) return;
391    float g = gate[id];
392    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397//   gate = gate_up[0..n], up = gate_up[n..2*n]
398//   output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400    device const float* gate_up [[buffer(0)]],
401    device float* output        [[buffer(1)]],
402    constant uint& n            [[buffer(2)]],
403    uint id [[thread_position_in_grid]])
404{
405    if (id >= n) return;
406    float g = gate_up[id];
407    float u = gate_up[n + id];
408    output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414    device const float* a  [[buffer(0)]],
415    device const float* b  [[buffer(1)]],
416    device float* output   [[buffer(2)]],
417    constant uint& n       [[buffer(3)]],
418    uint id [[thread_position_in_grid]])
419{
420    if (id >= n) return;
421    output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428    device const float* src [[buffer(0)]],
429    device float* dst       [[buffer(1)]],
430    constant uint& count    [[buffer(2)]],
431    uint id [[thread_position_in_grid]])
432{
433    if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440    device const float* src     [[buffer(0)]],
441    device float* dst           [[buffer(1)]],
442    constant uint& src_offset   [[buffer(2)]],  // in floats
443    constant uint& count        [[buffer(3)]],
444    uint id [[thread_position_in_grid]])
445{
446    if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
450// Copy f32 elements from src into a half-typed dst, converting to half on
451// write.  Used by the single-token decode path to append a new K/V vector
452// to the f16 KV cache.  Byte offsets into src/dst are supplied via the
453// Metal buffer binding offsets — no in-kernel offsets needed.
454kernel void copy_f32_to_f16_offset(
455    device const float* src     [[buffer(0)]],
456    device half* dst            [[buffer(1)]],
457    constant uint& count        [[buffer(2)]],
458    uint id [[thread_position_in_grid]])
459{
460    if (id < count) dst[id] = half(src[id]);
461}
462
463// ── add_inplace ─────────────────────────────────────────────────────────
464// In-place residual connection: a[i] += b[i]
465// Avoids a separate blit copy for residual add, reducing encoder overhead.
466kernel void add_inplace(
467    device float* a        [[buffer(0)]],
468    device const float* b  [[buffer(1)]],
469    constant uint& n       [[buffer(2)]],
470    uint id [[thread_position_in_grid]])
471{
472    if (id >= n) return;
473    a[id] += b[id];
474}
475
476// ── matmul_vec_q8 ─────────────────────────────────────────────────────
477// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
478// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
479// Operates directly on quantized weights to halve memory bandwidth vs f32,
480// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
481//
482// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
483// because int8->float conversion doubles register demand.  Fully unrolled
484// inner loop with float4 vector loads from shared memory eliminates loop
485// overhead and enables better instruction scheduling.
486// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
487constant constexpr uint Q8_ROWS_PER_SG = 4;
488constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
489
490// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
491// doubles ALU work, so the same register budget applies).
492constant constexpr uint Q4_ROWS_PER_SG = 4;
493constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
494
495kernel void matmul_vec_q8(
496    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
497    device const float* vector   [[buffer(1)]],  // f32 input
498    device float* output         [[buffer(2)]],
499    constant uint& rows          [[buffer(3)]],
500    constant uint& cols          [[buffer(4)]],  // number of elements per row
501    uint tgid [[threadgroup_position_in_grid]],
502    uint tid [[thread_index_in_threadgroup]],
503    uint simd_lane [[thread_index_in_simdgroup]],
504    uint simd_id [[simdgroup_index_in_threadgroup]])
505{
506    // Load vector into shared memory
507    threadgroup float vec_tile[VEC_TILE_SIZE];
508    for (uint i = tid; i < cols; i += 256) {
509        vec_tile[i] = vector[i];
510    }
511    threadgroup_barrier(mem_flags::mem_threadgroup);
512
513    // Each simdgroup handles 4 consecutive rows (lower register pressure)
514    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
515    if (row_base >= rows) return;
516
517    // Q8_0: each block is 34 bytes for 32 elements
518    uint blocks_per_row = cols / 32;
519    uint row_bytes = blocks_per_row * 34;
520
521    // Pointers to each row's Q8_0 data
522    device const uchar* r0 = matrix + row_base * row_bytes;
523    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
524    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
525    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
526
527    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
528
529    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
530        uint bb = blk * 34;
531        uint vb = blk * 32;
532
533        // Prefetch all 4 scales
534        float sc0 = float(*(device const half*)(r0 + bb));
535        float sc1 = float(*(device const half*)(r1 + bb));
536        float sc2 = float(*(device const half*)(r2 + bb));
537        float sc3 = float(*(device const half*)(r3 + bb));
538
539        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
540        // Q8_0 block layout where the int8 data starts at offset +2 from a
541        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
542        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
543        // reduction in memory transactions. Metal's char16/packed_char16 are
544        // reserved types and packed_*int4 require >=4-byte alignment which
545        // this layout does not provide, so packed_short4 is the widest valid
546        // vectorized load.
547        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
548        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
549        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
550        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
551
552        // Load all 8 float4 vector values for this 32-element block from shared memory
553        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
554        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
555        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
556        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
557        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
558        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
559        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
560        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
561
562        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
563        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
564        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
565            short4 _s = short4(SHORT4); \
566            char2 _a = as_type<char2>(_s.x); \
567            char2 _b = as_type<char2>(_s.y); \
568            char2 _c = as_type<char2>(_s.z); \
569            char2 _d = as_type<char2>(_s.w); \
570            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
571            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
572        }
573
574        float4 f0, f1;
575        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
576
577        // Row 0: 4 short4 loads cover 32 int8 weights
578        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
579        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
580        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
581        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
582
583        // Row 1
584        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
585        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
586        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
587        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
588
589        // Row 2
590        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
591        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
592        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
593        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
594
595        // Row 3
596        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
597        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
598        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
599        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
600
601        #undef Q8_UNPACK8
602
603        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
604    }
605
606    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
607    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
608
609    if (simd_lane == 0) {
610        if (row_base     < rows) output[row_base]     = sum0;
611        if (row_base + 1 < rows) output[row_base + 1] = sum1;
612        if (row_base + 2 < rows) output[row_base + 2] = sum2;
613        if (row_base + 3 < rows) output[row_base + 3] = sum3;
614    }
615}
616
617// ── matmul_vec_q4 ─────────────────────────────────────────────────────
618// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
619// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
620// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
621// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
622//
623// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
624// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
625kernel void matmul_vec_q4(
626    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
627    device const float* vector   [[buffer(1)]],  // f32 input
628    device float* output         [[buffer(2)]],
629    constant uint& rows          [[buffer(3)]],
630    constant uint& cols          [[buffer(4)]],  // number of elements per row
631    uint tgid [[threadgroup_position_in_grid]],
632    uint tid [[thread_index_in_threadgroup]],
633    uint simd_lane [[thread_index_in_simdgroup]],
634    uint simd_id [[simdgroup_index_in_threadgroup]])
635{
636    // Load vector into shared memory
637    threadgroup float vec_tile[VEC_TILE_SIZE];
638    for (uint i = tid; i < cols; i += 256) {
639        vec_tile[i] = vector[i];
640    }
641    threadgroup_barrier(mem_flags::mem_threadgroup);
642
643    // Each simdgroup handles 4 consecutive rows
644    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
645    if (row_base >= rows) return;
646
647    // Q4_0: each block is 18 bytes for 32 elements
648    uint blocks_per_row = cols / 32;
649    uint row_bytes = blocks_per_row * 18;
650
651    // Pointers to each row's Q4_0 data
652    device const uchar* r0 = matrix + row_base * row_bytes;
653    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
654    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
655    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
656
657    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
658
659    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
660        uint bb = blk * 18;
661        uint vb = blk * 32;
662
663        // Prefetch all 4 scales
664        float sc0 = float(*(device const half*)(r0 + bb));
665        float sc1 = float(*(device const half*)(r1 + bb));
666        float sc2 = float(*(device const half*)(r2 + bb));
667        float sc3 = float(*(device const half*)(r3 + bb));
668
669        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
670        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
671        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
672        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
673        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
674
675        // Load 8 float4 vector values for 32 elements from shared memory
676        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
677        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
678        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
679        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
680        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
681        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
682        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
683        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
684        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
685
686        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
687        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
688        float bd0=0, bd1=0, bd2=0, bd3=0;
689        uchar4 b;
690
691        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
692        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
693                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
694                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
695                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
696        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
697                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
698                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
699                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
700        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
701                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
702                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
703                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
704        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
705                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
706                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
707                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
708
709        // Row 1
710        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
711                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
712                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
713                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
714        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
715                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
716                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
717                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
718        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
719                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
720                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
721                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
722        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
723                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
724                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
725                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
726
727        // Row 2
728        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
729                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
730                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
731                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
732        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
733                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
734                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
735                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
736        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
737                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
738                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
739                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
740        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
741                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
742                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
743                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
744
745        // Row 3
746        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
747                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
748                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
749                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
750        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
751                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
752                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
753                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
754        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
755                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
756                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
757                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
758        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
759                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
760                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
761                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
762
763        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
764    }
765
766    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
767    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
768
769    if (simd_lane == 0) {
770        if (row_base     < rows) output[row_base]     = sum0;
771        if (row_base + 1 < rows) output[row_base + 1] = sum1;
772        if (row_base + 2 < rows) output[row_base + 2] = sum2;
773        if (row_base + 3 < rows) output[row_base + 3] = sum3;
774    }
775}
776
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention: one threadgroup per query head.
//
// Pipeline:
//   1. scores[s] = (q · k_s) * rsqrt(head_dim)   one simdgroup per position s
//   2. softmax over scores[0 .. seq_len)          threadgroup-wide reductions
//   3. output    = Σ_s scores[s] * v_s            one simdgroup per 4-wide chunk
//
// The loop strides (8 simdgroups, 32 lanes, 256 threads) are hard-coded, so
// the kernel expects 256 threads per threadgroup. scores[] holds
// ATTN_SCORES_SIZE entries; the caller presumably guarantees
// seq_len <= ATTN_SCORES_SIZE (nothing in this kernel enforces it).
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const half*  k_cache  [[buffer(1)]],
    device const half*  v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Threadgroup scratch: the score row plus one reduction slot per simdgroup.
    threadgroup float scores[ATTN_SCORES_SIZE];
    threadgroup float shared_max[8];
    threadgroup float shared_sum[8];

    // The whole threadgroup exits together, so no barrier is skipped.
    if (tgid >= num_heads) return;

    const uint head         = tgid;
    // Several query heads share one KV head.
    const uint kv_head      = head / (num_heads / num_kv_heads);
    const uint q_base       = head * head_dim;
    const uint kv_stride    = num_kv_heads * head_dim;  // elements per cached position
    const uint kv_head_base = kv_head * head_dim;
    const uint chunks       = head_dim / 4;             // number of 4-wide chunks
    const uint chunked_dims = chunks * 4;               // dims covered by the 4-wide path
    const float scale       = fast::rsqrt(float(head_dim));

    // ── Step 1: scaled Q·K^T ────────────────────────────────────────────
    // Simdgroup g scores positions g, g+8, g+16, ...; lanes split head_dim
    // into 4-wide chunks and the partial products are combined by simd_sum.
    for (uint pos = simd_id; pos < seq_len; pos += 8) {
        const uint k_base = pos * kv_stride + kv_head_base;
        float acc = 0.0;
        for (uint c = simd_lane; c < chunks; c += 32) {
            const uint d = c * 4;
            const float4 qv = *(device const float4*)(q + q_base + d);
            const float4 kv = float4(*(device const half4*)(k_cache + k_base + d));
            acc += qv.x * kv.x + qv.y * kv.y + qv.z * kv.z + qv.w * kv.w;
        }
        // Leftover dims when head_dim is not a multiple of 4.
        for (uint d = chunked_dims + simd_lane; d < head_dim; d += 32) {
            acc += q[q_base + d] * float(k_cache[k_base + d]);
        }
        acc = simd_sum(acc);
        if (simd_lane == 0) {
            scores[pos] = acc * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Step 2: softmax over scores ─────────────────────────────────────
    // 2a. Max: per-thread scan, simd_max, then thread 0 folds the 8 partials.
    float thread_max = -INFINITY;
    for (uint pos = tid; pos < seq_len; pos += 256) {
        thread_max = max(thread_max, scores[pos]);
    }
    thread_max = simd_max(thread_max);
    if (simd_lane == 0) shared_max[simd_id] = thread_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float best = shared_max[0];
        for (uint g = 1; g < 8; g++) best = max(best, shared_max[g]);
        shared_max[0] = best;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float row_max = shared_max[0];

    // 2b. Exponentiate in place (shifted by the max) and accumulate the sum.
    float thread_sum = 0.0;
    for (uint pos = tid; pos < seq_len; pos += 256) {
        scores[pos] = fast::exp(scores[pos] - row_max);
        thread_sum += scores[pos];
    }
    thread_sum = simd_sum(thread_sum);
    if (simd_lane == 0) shared_sum[simd_id] = thread_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float denom = 0.0;
        for (uint g = 0; g < 8; g++) denom += shared_sum[g];
        shared_sum[0] = 1.0 / denom;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float inv_denom = shared_sum[0];

    // 2c. Normalize. Each thread touches the same entries it exponentiated.
    for (uint pos = tid; pos < seq_len; pos += 256) {
        scores[pos] *= inv_denom;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Step 3: output = scores · V ─────────────────────────────────────
    // Simdgroup g owns chunks g, g+8, ...; within a chunk the 32 lanes split
    // the sequence positions and their partial sums are merged by simd_sum.
    for (uint c = simd_id; c < chunks; c += 8) {
        const uint d = c * 4;
        const uint v_off = kv_head_base + d;
        float4 acc4 = float4(0.0);
        for (uint pos = simd_lane; pos < seq_len; pos += 32) {
            const half4 vv = *(device const half4*)(v_cache + pos * kv_stride + v_off);
            acc4 += scores[pos] * float4(vv);
        }
        acc4.x = simd_sum(acc4.x);
        acc4.y = simd_sum(acc4.y);
        acc4.z = simd_sum(acc4.z);
        acc4.w = simd_sum(acc4.w);
        if (simd_lane == 0) {
            *(device float4*)(output + q_base + d) = acc4;
        }
    }
    // Leftover dims when head_dim is not a multiple of 4: one simdgroup per dim.
    for (uint d = chunked_dims + simd_id; d < head_dim; d += 8) {
        const uint v_off = kv_head_base + d;
        float acc = 0.0;
        for (uint pos = simd_lane; pos < seq_len; pos += 32) {
            acc += scores[pos] * float(v_cache[pos * kv_stride + v_off]);
        }
        acc = simd_sum(acc);
        if (simd_lane == 0) {
            output[q_base + d] = acc;
        }
    }
}
926
927// ── Batched prefill kernels ────────────────────────────────────────────
928// These kernels process M input vectors against the same weight matrix
929// in a single dispatch, converting mat-vec into mat-mat for better GPU
930// utilization during prompt prefill.
931
// ── rms_norm_batch ─────────────────────────────────────────────────────
// RMS normalization applied independently to each of M vectors of length n:
//   output[t, i] = input[t, i] * rsqrt(mean(input[t, :]^2) + eps) * weight[i]
// Grid: one threadgroup per token. The loops stride by 256 and the scratch
// array has 8 slots, so 256 threads (8 simdgroups) per threadgroup are expected.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // One scratch slot per simdgroup; slot 0 is reused for the final scale.
    threadgroup float shared[8];

    if (tgid >= num_tokens) return;

    const uint row_off = tgid * n;
    const uint lane    = tid % 32;
    const uint group   = tid / 32;

    // Per-thread partial sum of squares over a strided slice of the row.
    float partial = 0.0f;
    for (uint i = tid; i < n; i += 256) {
        const float x = input[row_off + i];
        partial += x * x;
    }

    // Reduce within each simdgroup, then publish one value per simdgroup.
    partial = simd_sum(partial);
    if (lane == 0) {
        shared[group] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the 8 simdgroup partials and stores 1/rms in slot 0.
    if (tid == 0) {
        float sum_sq = 0.0f;
        for (uint g = 0; g < 8; g++) {
            sum_sq += shared[g];
        }
        shared[0] = fast::rsqrt(sum_sq / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float scale = shared[0];

    for (uint i = tid; i < n; i += 256) {
        output[row_off + i] = input[row_off + i] * scale * weight[i];
    }
}
981
// ── rope_batch ─────────────────────────────────────────────────────────
// Rotary Position Embedding applied in place to a batch of vectors, each with
// its own position. data layout: [M, num_heads * head_dim], positions: [M].
// One thread per (token, head, pair); pair i rotates the adjacent elements
// (2i, 2i+1) of a head by angle = position * theta^(-2i / head_dim).
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint half_dim        = head_dim / 2;
    const uint pairs_per_token = num_heads * half_dim;

    // Threads beyond the last (token, head, pair) triple have nothing to do.
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread id into (token, head, pair).
    const uint token  = id / pairs_per_token;
    const uint within = id % pairs_per_token;
    const uint head   = within / half_dim;
    const uint pair   = within % half_dim;

    // Index of the even element of this pair.
    const uint even_idx = token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle    = float(positions[token]) * inv_freq;
    const float cs = cos(angle);
    const float sn = sin(angle);

    // Read both elements before writing either: the rotation needs the old values.
    const float even = data[even_idx];
    const float odd  = data[even_idx + 1];
    data[even_idx]     = even * cs - odd * sn;
    data[even_idx + 1] = even * sn + odd * cs;
}
1016
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-multiply for a batch. Each input row holds n gate values followed
// by n up values (layout [M, 2*n]); one thread produces one output element:
//   output[token*n + i] = silu(gate_up[token*2*n + i]) * gate_up[token*2*n + n + i]
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Also covers n == 0: the product is 0, so every thread leaves before dividing.
    if (id >= num_tokens * n) return;

    const uint token = id / n;
    const uint col   = id % n;

    const uint gate_idx = token * 2 * n + col;
    const float gate = gate_up[gate_idx];
    const float up   = gate_up[gate_idx + n];

    // token * n + col is exactly id, so the flat id doubles as the output index.
    output[id] = (gate / (1.0f + fast::exp(-gate))) * up;
}
1036
// ── add_inplace_batch ──────────────────────────────────────────────────
// In-place residual add over a flattened batch: a[i] = a[i] + b[i] for every
// i in [0, total), one thread per element.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    // Threads past the end of the buffers do nothing.
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1048
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous batch buffer.
// tokens[m] selects the table row copied into output row m; one thread copies
// one float.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    // Flat id -> (batch slot, column within the row).
    const uint slot = id / dim;
    const uint col  = id % dim;

    const uint table_row = tokens[slot];
    output[id] = embed[table_row * dim + col];
}
1066
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: M input vectors against one FP32 weight
// matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups; threadgroup `tgid`
// handles token tgid / row_tgs and row group tgid % row_tgs.
//
// Each simdgroup accumulates up to 8 consecutive rows starting at row_base.
// Columns below (cols & ~127) are processed as float4 with a stride of 128
// per simdgroup; the remaining columns are handled one scalar per lane.
// The staging loop strides by 256 (256 threads per threadgroup expected) and
// cols must fit in vec_tile[VEC_TILE_SIZE].
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    threadgroup float vec_tile[VEC_TILE_SIZE];

    // Split the flat threadgroup id into (token, row group).
    const uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    const uint token            = tgid / groups_per_token;
    const uint group            = tgid % groups_per_token;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    device const float* x = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barrier follows, so simdgroups with no rows can leave here.
    const uint row_base = group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Rows accumulated per simdgroup. Constant trip counts keep the loops
    // below unrollable, with one accumulator per row.
    const uint kRows = 8;
    const uint vec_cols = cols & ~127u;

    float4 acc4[kRows];
    for (uint r = 0; r < kRows; ++r) {
        acc4[r] = float4(0.0f);
    }

    // Vectorized body: each lane handles one float4 per 128-column step.
    for (uint j = simd_lane * 4; j < vec_cols; j += 128) {
        const float4 xv = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        for (uint r = 0; r < kRows; ++r) {
            const uint row = row_base + r;
            // Rows past the end of the matrix are never read.
            if (row < rows) {
                acc4[r] += *reinterpret_cast<device const float4*>(matrix + row * cols + j) * xv;
            }
        }
    }

    // Collapse each float4 accumulator to a scalar (x, y, z, w in order).
    float acc[kRows];
    for (uint r = 0; r < kRows; ++r) {
        acc[r] = acc4[r].x + acc4[r].y + acc4[r].z + acc4[r].w;
    }

    // Scalar tail for the columns the float4 loop did not cover.
    for (uint j = vec_cols + simd_lane; j < cols; j += 32) {
        const float xs = vec_tile[j];
        for (uint r = 0; r < kRows; ++r) {
            const uint row = row_base + r;
            if (row < rows) {
                acc[r] += matrix[row * cols + j] * xs;
            }
        }
    }

    // Combine the 32 lanes of the simdgroup for every row.
    for (uint r = 0; r < kRows; ++r) {
        acc[r] = simd_sum(acc[r]);
    }

    // Lane 0 writes the results for the rows that exist.
    device float* out = outputs + token * rows;
    if (simd_lane == 0) {
        for (uint r = 0; r < kRows; ++r) {
            const uint row = row_base + r;
            if (row < rows) {
                out[row] = acc[r];
            }
        }
    }
}
1168
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups; threadgroup `tgid`
// handles token tgid / row_tgs and row group tgid % row_tgs.
//
// Weight layout as read here: each row is cols/32 blocks of 34 bytes — a
// half-precision scale followed by 32 signed 8-bit quants. Columns beyond
// the last full block (cols % 32) are ignored, so cols is expected to be a
// multiple of 32.
//
// Each simdgroup produces up to 4 consecutive rows starting at row_base; its
// 32 lanes split the blocks of a row and are combined with simd_sum. The
// staging loop strides by 256 (256 threads per threadgroup expected) and
// cols must fit in vec_tile[VEC_TILE_SIZE].
//
// When rows is not a multiple of 4, the surplus rows of the last simdgroup
// are clamped to the last valid row so that no weight read goes past the end
// of the matrix; their sums are computed but never written.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barrier follows, so simdgroups with no rows can leave here.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // row_base < rows here, so rows >= 1 and last_row cannot underflow.
    // Clamping keeps every row pointer inside the matrix; a clamped row just
    // recomputes the last valid row and is dropped by the write guards below.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Lane L handles blocks L, L+32, L+64, ... of each of the 4 rows.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of this block within a row
        uint vb = blk * 32;  // index of this block's first input element

        // Per-block scale: the half stored in the first 2 bytes of the block.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 input values that pair with this block, as 8 float4s.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 bytes) into 8 signed quants:
        // OUT_LO receives bytes 0..3, OUT_HI receives bytes 4..7.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the 32 lanes' partial sums for each row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes the results; rows past the end (clamped above) are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1284
1285// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
1286// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
1287// Each threadgroup covers 32 rows and TOKENS_PER_TG consecutive tokens, so
1288// the Q8_0 weight blocks are fetched once from device memory and reused for
1289// every token in the tile (1/TOKENS_PER_TG the weight bandwidth of the
1290// per-token dispatch).
1291//
1292// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG)) threadgroups.
1293// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
1294// with simd_sum.  Token vectors are read directly from device memory inside
1295// the block loop (not cached in shared memory) so intermediate_size up to
1296// 8192 fits without spilling threadgroup memory.
1297constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1298
1299kernel void matmul_q8_gemm_batch(
1300    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1301    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1302    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1303    constant uint& num_tokens    [[buffer(3)]],  // M
1304    constant uint& rows          [[buffer(4)]],
1305    constant uint& cols          [[buffer(5)]],
1306    uint2 tgid [[threadgroup_position_in_grid]],
1307    uint tid [[thread_index_in_threadgroup]],
1308    uint simd_lane [[thread_index_in_simdgroup]],
1309    uint simd_id [[simdgroup_index_in_threadgroup]])
1310{
1311    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1312    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1313    if (row_base >= rows || tok_base >= num_tokens) return;
1314
1315    // How many tokens in this tile are valid?
1316    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1317
1318    uint blocks_per_row = cols / 32;
1319    uint row_bytes = blocks_per_row * 34;
1320
1321    device const uchar* r0 = matrix + row_base * row_bytes;
1322    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1323    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1324    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1325
1326    // Accumulators: 4 tokens × 4 rows per simdgroup.
1327    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1328    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1329    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1330    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1331
1332    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1333        uint bb = blk * 34;
1334        uint vb = blk * 32;
1335
1336        // ── Load weight data ONCE per block (reused across all tokens) ──
1337        float sc0 = float(*(device const half*)(r0 + bb));
1338        float sc1 = float(*(device const half*)(r1 + bb));
1339        float sc2 = float(*(device const half*)(r2 + bb));
1340        float sc3 = float(*(device const half*)(r3 + bb));
1341
1342        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1343        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1344        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1345        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1346
1347        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1348            short4 _s = short4(SHORT4); \
1349            char2 _a = as_type<char2>(_s.x); \
1350            char2 _b = as_type<char2>(_s.y); \
1351            char2 _c = as_type<char2>(_s.z); \
1352            char2 _d = as_type<char2>(_s.w); \
1353            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1354            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1355        }
1356
1357        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1358        // registers for the duration of the block and are dotted against
1359        // every token's vector tile.
1360        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1361        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1362        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1363        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1364
1365        Q8_UNPACK8(d0[0], w0_0, w0_1);
1366        Q8_UNPACK8(d0[1], w0_2, w0_3);
1367        Q8_UNPACK8(d0[2], w0_4, w0_5);
1368        Q8_UNPACK8(d0[3], w0_6, w0_7);
1369
1370        Q8_UNPACK8(d1[0], w1_0, w1_1);
1371        Q8_UNPACK8(d1[1], w1_2, w1_3);
1372        Q8_UNPACK8(d1[2], w1_4, w1_5);
1373        Q8_UNPACK8(d1[3], w1_6, w1_7);
1374
1375        Q8_UNPACK8(d2[0], w2_0, w2_1);
1376        Q8_UNPACK8(d2[1], w2_2, w2_3);
1377        Q8_UNPACK8(d2[2], w2_4, w2_5);
1378        Q8_UNPACK8(d2[3], w2_6, w2_7);
1379
1380        Q8_UNPACK8(d3[0], w3_0, w3_1);
1381        Q8_UNPACK8(d3[1], w3_2, w3_3);
1382        Q8_UNPACK8(d3[2], w3_4, w3_5);
1383        Q8_UNPACK8(d3[3], w3_6, w3_7);
1384
1385        #undef Q8_UNPACK8
1386
1387        // ── For each token, read vector and accumulate against shared weights ──
1388        // Token 0 (always valid: tok_count >= 1).
1389        {
1390            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1391            float4 v0 = *(device const float4*)(a0);
1392            float4 v1 = *(device const float4*)(a0 + 4);
1393            float4 v2 = *(device const float4*)(a0 + 8);
1394            float4 v3 = *(device const float4*)(a0 + 12);
1395            float4 v4 = *(device const float4*)(a0 + 16);
1396            float4 v5 = *(device const float4*)(a0 + 20);
1397            float4 v6 = *(device const float4*)(a0 + 24);
1398            float4 v7 = *(device const float4*)(a0 + 28);
1399            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1400                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1401            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1402                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1403            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1404                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1405            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1406                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1407            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1408        }
1409        // Token 1
1410        if (tok_count > 1) {
1411            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1412            float4 v0 = *(device const float4*)(a1);
1413            float4 v1 = *(device const float4*)(a1 + 4);
1414            float4 v2 = *(device const float4*)(a1 + 8);
1415            float4 v3 = *(device const float4*)(a1 + 12);
1416            float4 v4 = *(device const float4*)(a1 + 16);
1417            float4 v5 = *(device const float4*)(a1 + 20);
1418            float4 v6 = *(device const float4*)(a1 + 24);
1419            float4 v7 = *(device const float4*)(a1 + 28);
1420            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1421                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1422            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1423                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1424            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1425                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1426            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1427                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1428            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1429        }
1430        // Token 2
1431        if (tok_count > 2) {
1432            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1433            float4 v0 = *(device const float4*)(a2);
1434            float4 v1 = *(device const float4*)(a2 + 4);
1435            float4 v2 = *(device const float4*)(a2 + 8);
1436            float4 v3 = *(device const float4*)(a2 + 12);
1437            float4 v4 = *(device const float4*)(a2 + 16);
1438            float4 v5 = *(device const float4*)(a2 + 20);
1439            float4 v6 = *(device const float4*)(a2 + 24);
1440            float4 v7 = *(device const float4*)(a2 + 28);
1441            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1442                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1443            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1444                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1445            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1446                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1447            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1448                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1449            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1450        }
1451        // Token 3
1452        if (tok_count > 3) {
1453            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1454            float4 v0 = *(device const float4*)(a3);
1455            float4 v1 = *(device const float4*)(a3 + 4);
1456            float4 v2 = *(device const float4*)(a3 + 8);
1457            float4 v3 = *(device const float4*)(a3 + 12);
1458            float4 v4 = *(device const float4*)(a3 + 16);
1459            float4 v5 = *(device const float4*)(a3 + 20);
1460            float4 v6 = *(device const float4*)(a3 + 24);
1461            float4 v7 = *(device const float4*)(a3 + 28);
1462            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1463                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1464            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1465                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1466            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1467                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1468            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1469                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1470            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1471        }
1472    }
1473
1474    // simdgroup reduction
1475    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1476    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1477    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1478    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1479
1480    if (simd_lane == 0) {
1481        device float* o0 = outputs + (tok_base + 0) * rows;
1482        if (row_base     < rows) o0[row_base]     = s00;
1483        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1484        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1485        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1486
1487        if (tok_count > 1) {
1488            device float* o1 = outputs + (tok_base + 1) * rows;
1489            if (row_base     < rows) o1[row_base]     = s10;
1490            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1491            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1492            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1493        }
1494        if (tok_count > 2) {
1495            device float* o2 = outputs + (tok_base + 2) * rows;
1496            if (row_base     < rows) o2[row_base]     = s20;
1497            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1498            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1499            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1500        }
1501        if (tok_count > 3) {
1502            device float* o3 = outputs + (tok_base + 3) * rows;
1503            if (row_base     < rows) o3[row_base]     = s30;
1504            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1505            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1506            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1507        }
1508    }
1509}
1510
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
//
// Computes outputs[tok * rows + row] = sum over k of
//     inputs[tok * cols + k] * dequant(matrix[row, k])
// where each Q8_0 block is read as 34 bytes: one half scale followed by
// 32 int8 quants.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups (128 threads) per TG, each computing a single 8×8 output
// sub-tile via one simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes
// are cooperatively dequantized into threadgroup memory once per block and
// reused by all simdgroups in the tile.
//
// Threadgroup memory: two 16×32 float tiles (2 KB each, 4 KB total).  The
// partial-token store path stages through w_tile, which is no longer read
// once the K loop has finished, so no additional scratch is allocated.
//
// Assumptions (verified in the dispatch helper, falls back otherwise):
//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
//   * 128 threads per threadgroup (the cooperative loads give each thread
//     4 consecutive tile elements)
//   * num_tokens may be any value; a partial token tile is handled via a
//     staged copy of only the valid rows.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // The condition is uniform across the threadgroup, so returning before
    // the barriers below cannot leave other threads waiting.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8;  // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread owns 4 consecutive elements of both tiles.  Because 4
    // divides 32 they always fall inside a single tile row, so the weight
    // row, the token index and its bounds check are fixed for the whole
    // kernel and are computed once here instead of once per element.
    uint slot   = tid * 4;
    uint tile_r = slot / 32;   // weight row / token index inside the tile
    uint tile_k = slot % 32;   // first K offset inside the block

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok       = tok_base + tile_r;
    bool tok_valid = tok < num_tokens;
    uint in_off    = tok * cols + tile_k;  // only dereferenced when tok_valid

    // Tile origins of this simdgroup's A (tokens) and B (weights) fragments.
    threadgroup const float* a_src = t_tile + sg_tok_base * 32;
    threadgroup const float* b_src = w_tile + sg_row_base * 32;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // One scale load per thread per block; the 4 quants share it.
        device const uchar* blk_ptr = w_row + blk * 34;
        float scale = float(*(device const half*)blk_ptr);
        device const int8_t* quants = (device const int8_t*)(blk_ptr + 2 + tile_k);
        for (uint ii = 0; ii < 4; ii++) {
            w_tile[slot + ii] = float(int(quants[ii])) * scale;
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Tokens past num_tokens contribute zeros so the MMA stays uniform.
        uint in_idx = in_off + blk * 32;
        for (uint ii = 0; ii < 4; ii++) {
            t_tile[slot + ii] = tok_valid ? inputs[in_idx + ii] : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A, a_src + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B, b_src + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Keep fast simdgroups from overwriting the tiles while slower ones
        // are still reading them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    if (out_tok + 8 <= num_tokens) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage the accumulator in this simdgroup's
        // private 64-float slice of w_tile (every thread has passed the
        // final loop barrier, so nobody reads the weight tile any more),
        // then copy only the valid token rows, spread across the lanes.
        threadgroup float* stage = w_tile + simd_id * 64;
        simdgroup_store(C, stage, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint valid_elems = (num_tokens - out_tok) * 8;  // 1..7 rows of 8
        for (uint e = tid % 32; e < valid_elems; e += 32) {
            outputs[(out_tok + e / 8) * rows + out_row + (e % 8)] = stage[e];
        }
    }
}
1639
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* 8×8 accumulators that sit side by side along the
// row axis:
//
//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//     output sub-tiles (tok, row):
//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration — better FLOP/load ratio than the 16×16 single-accumulator
// kernel — and halves the number of threadgroups vs the 16×16 tile.
//
// Threadgroup memory: two 32×32 float tiles (4 KB each, 8 KB total).  The
// partial-token store path stages through w_tile (exactly 8 simdgroups ×
// 2 accumulators × 64 floats), so no additional scratch is allocated.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
//   * 256 threads per threadgroup (each thread fills 4 consecutive tile
//     elements in the cooperative loads)
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so this early return is barrier-safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread owns 4 consecutive elements of both tiles; since 4 divides
    // 32 they share one tile row, so the per-row quantities are computed
    // once here rather than once per element.
    uint slot   = tid * 4;
    uint tile_r = slot / 32;   // weight row / token index inside the tile
    uint tile_k = slot % 32;   // first K offset inside the block

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok       = tok_base + tile_r;
    bool tok_valid = tok < num_tokens;
    uint in_off    = tok * cols + tile_k;  // only dereferenced when tok_valid

    // Tile origins of this simdgroup's A fragment and its two B fragments.
    threadgroup const float* a_src   = t_tile + sg_tok_base * 32;
    threadgroup const float* b_src_a = w_tile + sg_row_base_a * 32;
    threadgroup const float* b_src_b = w_tile + sg_row_base_b * 32;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads =
        // 4 floats each, sharing a single scale load per thread per block.
        device const uchar* blk_ptr = w_row + blk * 34;
        float scale = float(*(device const half*)blk_ptr);
        device const int8_t* quants = (device const int8_t*)(blk_ptr + 2 + tile_k);
        for (uint ii = 0; ii < 4; ii++) {
            w_tile[slot + ii] = float(int(quants[ii])) * scale;
        }

        // Cooperative token tile load; tokens past num_tokens become zeros.
        uint in_idx = in_off + blk * 32;
        for (uint ii = 0; ii < 4; ii++) {
            t_tile[slot + ii] = tok_valid ? inputs[in_idx + ii] : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,   a_src   + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_a, b_src_a + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_b, b_src_b + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Tiles must not be refilled until every simdgroup has consumed them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Stage both accumulators in this simdgroup's private 128-float
        // slice of w_tile.  All threads have passed the final loop barrier,
        // so the weight tile is no longer read by anyone.
        threadgroup float* stage = w_tile + simd_id * 128;
        simdgroup_store(C_a, stage, 8);
        simdgroup_store(C_b, stage + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // Copy only the valid token rows, spread across the lanes.
        uint valid_elems = (num_tokens - out_tok) * 8;  // 1..7 rows of 8
        for (uint e = tid % 32; e < valid_elems; e += 32) {
            uint out_off = (out_tok + e / 8) * rows + (e % 8);
            outputs[out_off + out_row_a] = stage[e];
            outputs[out_off + out_row_b] = stage[64 + e];
        }
    }
}
1780
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory — halving the tile footprint (4 KB total vs 8 KB of float tiles
// in matmul_q8_mma32), with the aim of fitting more concurrent
// threadgroups per GPU core on Apple Silicon.
//
// Numerics: each Q8_0 weight is an int8 quant times a per-block scale that
// is itself stored as `half`; the dequantized product is rounded to `half`
// when written to the tile.  Token activations are likewise narrowed
// f32 → f16 on load, so they are rounded to half precision and must lie
// within the half range.  Products are accumulated in `float` by
// `simdgroup_multiply_accumulate`, so no further precision is lost in the
// reduction over K.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups (256 threads) × 2 8×8
// accumulators each, side by side along the row axis.  Primary win vs
// mma32 is occupancy at moderate prefill lengths where the GPU is
// wave-starved.
//
// Threadgroup memory: both half tiles are carved from one 4 KB buffer,
// which is re-viewed as float staging space for the partial-token store
// once the K loop has finished, so the kernel allocates 4 KB in total.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so this early return is barrier-safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // One 4 KB backing buffer: 1024 halves of weights followed by 1024
    // halves of tokens (2 KB each).  Declared as float so the same storage
    // can later hold 8 simdgroups × 2 accumulators × 64 floats.
    threadgroup float tile_mem[(MMA32_ROW_TILE + MMA32_TOK_TILE) * 32 / 2];
    threadgroup half* w_tile = (threadgroup half*)tile_mem;
    threadgroup half* t_tile = w_tile + MMA32_ROW_TILE * 32;

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Each thread owns 4 consecutive elements of both tiles; since 4 divides
    // 32 they share one tile row, so the per-row quantities are computed
    // once here rather than once per element.
    uint slot   = tid * 4;
    uint tile_r = slot / 32;   // weight row / token index inside the tile
    uint tile_k = slot % 32;   // first K offset inside the block

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok       = tok_base + tile_r;
    bool tok_valid = tok < num_tokens;
    uint in_off    = tok * cols + tile_k;  // only dereferenced when tok_valid

    // Tile origins of this simdgroup's A fragment and its two B fragments.
    threadgroup const half* a_src   = t_tile + sg_tok_base * 32;
    threadgroup const half* b_src_a = w_tile + sg_row_base_a * 32;
    threadgroup const half* b_src_b = w_tile + sg_row_base_b * 32;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16: one scale load per
        // thread per block, shared by its 4 quants.
        device const uchar* blk_ptr = w_row + blk * 34;
        float scale = float(*(device const half*)blk_ptr);
        device const int8_t* quants = (device const int8_t*)(blk_ptr + 2 + tile_k);
        for (uint ii = 0; ii < 4; ii++) {
            w_tile[slot + ii] = half(float(int(quants[ii])) * scale);
        }

        // Cooperative token tile load (f32 → f16 narrowing); tokens past
        // num_tokens become zeros.
        uint in_idx = in_off + blk * 32;
        for (uint ii = 0; ii < 4; ii++) {
            t_tile[slot + ii] = tok_valid ? half(inputs[in_idx + ii]) : half(0);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each; half operands, float accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,   a_src   + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_a, b_src_a + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_b, b_src_b + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Tiles must not be refilled until every simdgroup has consumed them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both accumulators; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Stage both accumulators in this simdgroup's private 128-float
        // slice of the backing buffer.  All threads have passed the final
        // loop barrier, so neither half tile is read any more.
        threadgroup float* stage = tile_mem + simd_id * 128;
        simdgroup_store(C_a, stage, 8);
        simdgroup_store(C_b, stage + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // Copy only the valid token rows, spread across the lanes.
        uint valid_elems = (num_tokens - out_tok) * 8;  // 1..7 rows of 8
        for (uint e = tid % 32; e < valid_elems; e += 32) {
            uint out_off = (out_tok + e / 8) * rows + (e % 8);
            outputs[out_off + out_row_a] = stage[e];
            outputs[out_off + out_row_b] = stage[64 + e];
        }
    }
}
1909
1910// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
1911// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
1912//
1913// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
1914// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
1915//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
1916//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
1917// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
1918//
1919// Per K_sub iteration: load two A fragments and two B fragments, then run
1920// **four** MMA instructions reusing A_top with both B's and A_bot with
1921// both B's.  That's double the FLOP-per-simdgroup-load compared to the
1922// 2-accumulator kernel and halves the thread count per threadgroup (128
1923// threads), which often improves occupancy on Apple GPUs where the
1924// concurrent-thread budget is the tighter limit than shared-memory size.
1925kernel void matmul_q8_mma32_h4(
1926    device const uchar* matrix   [[buffer(0)]],
1927    device const float* inputs   [[buffer(1)]],
1928    device float* outputs        [[buffer(2)]],
1929    constant uint& num_tokens    [[buffer(3)]],
1930    constant uint& rows          [[buffer(4)]],
1931    constant uint& cols          [[buffer(5)]],
1932    uint2 tgid [[threadgroup_position_in_grid]],
1933    uint tid [[thread_index_in_threadgroup]],
1934    uint simd_id [[simdgroup_index_in_threadgroup]])
1935{
1936    uint row_base = tgid.x * MMA32_ROW_TILE;
1937    uint tok_base = tgid.y * MMA32_TOK_TILE;
1938    if (row_base >= rows || tok_base >= num_tokens) return;
1939
1940    // 32×32 FP16 tiles, 4 KB total.
1941    threadgroup half w_tile[MMA32_ROW_TILE * 32];
1942    threadgroup half t_tile[MMA32_TOK_TILE * 32];
1943
1944    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
1945    uint sg_tok_q = simd_id / 2;   // 0..1
1946    uint sg_row_q = simd_id % 2;   // 0..1
1947    uint sg_tok_base = sg_tok_q * 16;
1948    uint sg_row_base = sg_row_q * 16;
1949
1950    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
1951    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
1952    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
1953    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);
1954
1955    uint blocks_per_row = cols / 32;
1956    uint row_bytes = blocks_per_row * 34;
1957
1958    for (uint blk = 0; blk < blocks_per_row; blk++) {
1959        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
1960        {
1961            uint base = tid * 8;
1962            for (uint ii = 0; ii < 8; ii++) {
1963                uint idx = base + ii;
1964                uint r = idx / 32;
1965                uint k = idx % 32;
1966                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
1967                float sc = float(*(device const half*)rp);
1968                int ival = int(*(device const int8_t*)(rp + 2 + k));
1969                w_tile[r * 32 + k] = half(float(ival) * sc);
1970            }
1971        }
1972
1973        // Cooperative token tile load.
1974        {
1975            uint base = tid * 8;
1976            for (uint ii = 0; ii < 8; ii++) {
1977                uint idx = base + ii;
1978                uint m = idx / 32;
1979                uint k = idx % 32;
1980                uint tok = tok_base + m;
1981                t_tile[m * 32 + k] = (tok < num_tokens)
1982                    ? half(inputs[tok * cols + blk * 32 + k])
1983                    : half(0);
1984            }
1985        }
1986
1987        threadgroup_barrier(mem_flags::mem_threadgroup);
1988
1989        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
1990        for (uint k_sub = 0; k_sub < 4; k_sub++) {
1991            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
1992            simdgroup_load(A_top,
1993                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
1994                32,
1995                ulong2(0, 0),
1996                false);
1997            simdgroup_load(A_bot,
1998                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
1999                32,
2000                ulong2(0, 0),
2001                false);
2002            simdgroup_load(B_lo,
2003                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
2004                32,
2005                ulong2(0, 0),
2006                true);
2007            simdgroup_load(B_hi,
2008                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
2009                32,
2010                ulong2(0, 0),
2011                true);
2012            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
2013            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
2014            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
2015            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
2016        }
2017
2018        threadgroup_barrier(mem_flags::mem_threadgroup);
2019    }
2020
2021    // Store 4 output tiles.  Full-tile fast path assumes full 16×16 valid.
2022    uint out_tok_top = tok_base + sg_tok_base + 0;
2023    uint out_tok_bot = tok_base + sg_tok_base + 8;
2024    uint out_row_lo  = row_base + sg_row_base + 0;
2025    uint out_row_hi  = row_base + sg_row_base + 8;
2026    bool full = (out_tok_bot + 8 <= num_tokens);
2027    if (full) {
2028        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
2029        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
2030        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
2031        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
2032    } else {
2033        // Partial-token fallback via per-simdgroup scratch.
2034        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
2035        uint sg_base = simd_id * 256;
2036        simdgroup_store(C_00, scratch + sg_base +   0, 8);
2037        simdgroup_store(C_01, scratch + sg_base +  64, 8);
2038        simdgroup_store(C_10, scratch + sg_base + 128, 8);
2039        simdgroup_store(C_11, scratch + sg_base + 192, 8);
2040        simdgroup_barrier(mem_flags::mem_threadgroup);
2041        uint lane = tid % 32;
2042        if (lane == 0) {
2043            for (uint m = 0; m < 8; m++) {
2044                uint t_top = out_tok_top + m;
2045                if (t_top < num_tokens) {
2046                    device float* dst0 = outputs + t_top * rows + out_row_lo;
2047                    device float* dst1 = outputs + t_top * rows + out_row_hi;
2048                    threadgroup const float* src0 = scratch + sg_base +   0 + m * 8;
2049                    threadgroup const float* src1 = scratch + sg_base +  64 + m * 8;
2050                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
2051                }
2052                uint t_bot = out_tok_bot + m;
2053                if (t_bot < num_tokens) {
2054                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
2055                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
2056                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
2057                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
2058                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
2059                }
2060            }
2061        }
2062    }
2063}
2064
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)).  Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Data layout (as read by the code below):
//   matrix  — per weight row, cols/32 blocks of 34 bytes each: one half
//             scale followed by 32 int8 quants.
//   inputs  — f32, element (tok, k)   at inputs[tok * cols + k].
//   outputs — f32, element (tok, row) at outputs[tok * rows + row].
//
// Work split: one threadgroup per (row tile, token tile).  Each thread stages
// 8 weights and 8 activations per K block; simdgroup `simd_id` owns the 16x16
// output quadrant (simd_id / 2, simd_id % 2) as four 8x8 half accumulators.
// NOTE(review): the staging loops and the scratch sizing assume 128 threads
// (4 simdgroups) per threadgroup and MMA32_ROW_TILE == MMA32_TOK_TILE == 32 —
// confirm at the dispatch site.
//
// Tail handling: weight rows >= `rows` and tokens >= `num_tokens` are staged
// as zeros and are never written back, so neither dimension has to be a
// multiple of the tile size.  Only whole 32-element blocks of `cols` are
// processed.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so returning before any barrier is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    threadgroup half w_tile[MMA32_ROW_TILE * 32];   // dequantized weights, [row][k]
    threadgroup half t_tile[MMA32_TOK_TILE * 32];   // activations as half, [tok][k]
    threadgroup half scratch[4 * 4 * 64];           // 4 simdgroups × 4 accs × 64

    // Quadrant of the output tile owned by this simdgroup.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Stage one K block of weights: dequantize int8 * scale into half.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                half w = half(0);
                // Rows past the end of the matrix contribute zeros instead of
                // being read out of bounds.
                if (row_base + r < rows) {
                    device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                    float sc = float(*(device const half*)rp);
                    int ival = int(*(device const int8_t*)(rp + 2 + k));
                    w = half(float(ival) * sc);
                }
                w_tile[r * 32 + k] = w;
            }
        }
        // Stage the matching K block of activations, zero-padding missing tokens.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8; each loads two A (token) and two B (weight,
        // transposed) fragments and reuses them for the four accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Tiles are overwritten by the next iteration; wait for all readers.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    // Per-simdgroup scratch layout: accumulator a at offset a * 64, row-major 8x8,
    // with a = 0: C_00, 1: C_01, 2: C_10, 3: C_11.
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // All 32 lanes share the write-back (8 elements each); consecutive lanes
    // write consecutive output rows of the same token.
    uint lane = tid % 32;
    for (uint e = lane; e < 256; e += 32) {
        uint acc_id = e / 64;        // which of the four accumulators
        uint m = (e / 8) % 8;        // token row inside the 8x8 fragment
        uint j = e % 8;              // weight-row column inside the fragment
        uint tok = tok_base + sg_tok_base + (acc_id / 2) * 8 + m;
        uint row = row_base + sg_row_base + (acc_id % 2) * 8 + j;
        // Skip padding: a row index >= rows would land in the next token's output.
        if (tok < num_tokens && row < rows) {
            outputs[tok * rows + row] = float(scratch[sg_base + e]);
        }
    }
}
2202
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add: every token row of a row-major [num_tokens, rows]
// buffer receives the same per-row bias vector.
// Used for Qwen2 QKV bias after the fused qkv matmul.
//     out[token, i] += bias[i]    for i in 0..rows, token in 0..num_tokens
// One thread per output element; threads past the end of the buffer do nothing.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows], updated in place
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    const uint element_count = num_tokens * rows;
    if (id < element_count) {
        // Column inside the token row selects the bias entry.
        const uint bias_index = id % rows;
        out[id] = out[id] + bias[bias_index];
    }
}
2219
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup copies one token's input vector into threadgroup memory;
// each simdgroup then produces 4 consecutive output rows, with its 32 lanes
// striding over the 32-element blocks of a row and combining via simd_sum.
//
// Q4_0 block layout as read here (18 bytes per 32 weights): a half scale,
// then 16 bytes where byte i holds weight i in its low nibble and weight
// i + 16 in its high nibble, each stored with a +8 offset.
//
// NOTE(review): the tile fill strides by 256 and the code handles exactly 4
// rows per simdgroup, so this presumably expects 256 threads per threadgroup,
// Q4_ROWS_PER_SG == 4 and cols <= VEC_TILE_SIZE — confirm at the dispatch site.

// Dot product of four packed Q4_0 bytes (8 weights) with their activations:
// `lo` pairs with the low nibbles, `hi` with the high nibbles.
static inline float q4_batch_dot8(uchar4 b, float4 lo, float4 hi)
{
    return float(int(b.x & 0xF) - 8) * lo.x + float(int(b.x >> 4) - 8) * hi.x
         + float(int(b.y & 0xF) - 8) * lo.y + float(int(b.y >> 4) - 8) * hi.y
         + float(int(b.z & 0xF) - 8) * lo.z + float(int(b.z >> 4) - 8) * hi.z
         + float(int(b.w & 0xF) - 8) * lo.w + float(int(b.w >> 4) - 8) * hi.w;
}

// Unscaled dot product of one 32-weight Q4_0 block (its 16 quant bytes at `q`)
// with activations v0..v7, where v0..v3 hold elements 0..15 and v4..v7 hold
// elements 16..31 of the block.
static inline float q4_batch_block_dot(
    device const packed_uchar4* q,
    float4 v0, float4 v1, float4 v2, float4 v3,
    float4 v4, float4 v5, float4 v6, float4 v7)
{
    float acc = 0;
    uchar4 b;
    b = q[0]; acc += q4_batch_dot8(b, v0, v4);
    b = q[1]; acc += q4_batch_dot8(b, v1, v5);
    b = q[2]; acc += q4_batch_dot8(b, v2, v6);
    b = q[3]; acc += q4_batch_dot8(b, v3, v7);
    return acc;
}

kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Flat grid index → (token, row-group within that token).
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's input vector once for the whole threadgroup.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barrier follows, so a whole simdgroup may leave here.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // When rows is not a multiple of 4 the last group has fewer than 4 real
    // rows.  Clamp the extra row indices onto the last valid row so their
    // reads stay inside the matrix; the duplicated sums are discarded by the
    // guarded stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;   // byte offset of this block within a row
        uint vb = blk * 32;   // element offset of this block within the input

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // The 32 activations of this block are loaded once and shared by all 4 rows.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0 = q4_batch_block_dot(p0, v0, v1, v2, v3, v4, v5, v6, v7);
        float bd1 = q4_batch_block_dot(p1, v0, v1, v2, v3, v4, v5, v6, v7);
        float bd2 = q4_batch_block_dot(p2, v0, v1, v2, v3, v4, v5, v6, v7);
        float bd3 = q4_batch_block_dot(p3, v0, v1, v2, v3, v4, v5, v6, v7);

        // The block scale is applied once per block rather than per weight.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2320
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter one K or V slice per token from the batch QKV buffer into the KV
// cache, narrowing f32 → f16 on the way.
//   source element:      src[token * src_stride + src_offset + d]
//   destination element: dst[(base_pos + token) * kv_dim + d]
// One thread per (token, d) pair; threads beyond M * kv_dim do nothing.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer (f32)
    device half* dst         [[buffer(1)]],  // KV cache (f16)
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // elements per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    if (id < M * kv_dim) {
        // Split the flat thread id into (token row, element within the vector).
        const uint row = id / kv_dim;
        const uint col = id % kv_dim;

        const uint read_index  = row * src_stride + src_offset + col;
        const uint write_index = (base_pos + row) * kv_dim + col;
        dst[write_index] = half(src[read_index]);
    }
}
2343
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// All scores for the attended positions are materialized in a threadgroup
// array of ATTN_SCORES_SIZE floats.  The attended length is clamped to that
// size so a longer sequence cannot write past the array; positions beyond the
// cap are then ignored, so callers that need longer contexts should use
// attention_flash_batch, which keeps no per-position score array.
//
// NOTE(review): the loops stride by 8 simdgroups / 256 threads and the
// reductions read 8 partials, so this presumably expects 256 threads per
// threadgroup — confirm at the dispatch site.
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const half*  k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half*  v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    // Grouped-query attention: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    // Causal window is positions 0..base_pos+token_idx, clamped to the
    // capacity of the scores array below.
    uint seq_len = min(base_pos + token_idx + 1, uint(ATTN_SCORES_SIZE));

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // One score per attended position; see the clamp on seq_len above.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction.
    // Each simdgroup takes every 8th position; its lanes split head_dim.
    const float scale = fast::rsqrt(float(head_dim));
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a: maximum score, reduced per simdgroup and then by thread 0.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;   // slot 0 doubles as the broadcast value
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate relative to the max and sum.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Store the reciprocal so every thread multiplies instead of divides;
        // a zero total yields zero weights rather than inf/NaN.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // 2c: normalize.
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using vectorized loads.
    // Each active thread produces 4 consecutive output dims, reading V as half4
    // and unrolling the position loop by 4.
    // NOTE(review): the half4 / float4 accesses presumably rely on head_dim
    // being a multiple of 4 for alignment — confirm.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        uint seq_len4 = seq_len & ~3u;   // largest multiple of 4 <= seq_len
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
                 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
                 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
                 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
        }
        // Up to 3 leftover positions.
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2469
// ── attention_flash_batch ─────────────────────────────────────────────
// Causal attention with an online (streaming) softmax.  The grid matches
// attention_batch — one threadgroup per (token, head) pair, M × num_heads in
// total — but no per-position score array is kept.  K/V positions are
// consumed FLASH_K_TILE at a time while a running maximum m, a running
// denominator l and a running output vector O are maintained with the usual
// flash-attention update:
//
//     m_new   = max(m_old, tile_max)
//     alpha   = exp(m_old - m_new)
//     l_new   = alpha * l_old + sum(exp(S - m_new))
//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
//     O_final = O / l
//
// Threadgroup memory is therefore O(head_dim + FLASH_K_TILE) rather than
// O(seq_len), and the attended length is not bounded by a score buffer.
//
// Assumption: head_dim ≤ FLASH_MAX_HEAD_DIM, which sizes the Q and O
// buffers (the original author notes Llama/Qwen/Mistral/Phi-3 satisfy this).
// NOTE(review): the strides of 8 simdgroups / 256 threads presumably match
// the dispatched threadgroup size — confirm.
constant constexpr uint FLASH_K_TILE = 32;
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;

kernel void attention_flash_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const half*  k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half*  v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],
    constant uint& base_pos          [[buffer(5)]],
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decode the (token, head) pair served by this threadgroup.
    const uint token_idx = tgid / num_heads;
    const uint head = tgid % num_heads;
    if (token_idx >= M) return;

    const uint heads_per_kv = num_heads / num_kv_heads;
    const uint kv_head = head / heads_per_kv;
    // Causal window: positions [0, base_pos + token_idx].
    const uint seq_len = base_pos + token_idx + 1;
    const uint q_off = token_idx * q_stride + head * head_dim;
    const uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup-shared working set:
    //   q_vec      — this pair's query vector, loaded once
    //   o_acc      — unnormalized running output
    //   tile_p     — scores, then softmax numerators, of the current tile
    //   running    — [0] running max m, [1] running denominator l
    //   reduce_buf — per-simdgroup partials, also reused to broadcast scalars
    threadgroup float q_vec[FLASH_MAX_HEAD_DIM];
    threadgroup float o_acc[FLASH_MAX_HEAD_DIM];
    threadgroup float tile_p[FLASH_K_TILE];
    threadgroup float running[2];
    threadgroup float reduce_buf[8];

    // Initial state: Q loaded, O = 0, m = -inf, l = 0.
    for (uint d = tid; d < head_dim; d += 256) {
        q_vec[d] = q_batch[q_off + d];
        o_acc[d] = 0.0f;
    }
    if (tid == 0) {
        running[0] = -INFINITY;
        running[1] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float scale = fast::rsqrt(float(head_dim));
    // K and V rows share the same layout, so one stride/offset pair serves both.
    const uint kv_row_stride = num_kv_heads * head_dim;
    const uint kv_head_off = kv_head * head_dim;

    for (uint tile_start = 0; tile_start < seq_len; tile_start += FLASH_K_TILE) {
        // The final tile may be shorter than FLASH_K_TILE.
        const uint tile_len = min((uint)FLASH_K_TILE, seq_len - tile_start);

        // [1] Scaled dot products q · k for the tile.  Simdgroups take every
        //     8th position; the 32 lanes of a simdgroup split head_dim.
        for (uint ti = simd_id; ti < tile_len; ti += 8) {
            const uint k_off = (tile_start + ti) * kv_row_stride + kv_head_off;
            float partial = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                partial += q_vec[d] * k_cache[k_off + d];
            }
            partial = simd_sum(partial);
            if (simd_lane == 0) {
                tile_p[ti] = partial * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Maximum score of the tile: per-simdgroup partials first.
        float lane_max = -INFINITY;
        for (uint s = tid; s < tile_len; s += 256) {
            lane_max = max(lane_max, tile_p[s]);
        }
        lane_max = simd_max(lane_max);
        if (simd_lane == 0) {
            reduce_buf[simd_id] = lane_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [3] Thread 0 folds the tile max into the running max, rescales the
        //     running denominator, and publishes (alpha, m_new) through
        //     reduce_buf for everyone else.
        if (tid == 0) {
            float tile_max = reduce_buf[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, reduce_buf[i]);
            const float prev_max = running[0];
            const float next_max = max(prev_max, tile_max);
            // On the first tile prev_max is -inf; use 0 so O and l start fresh.
            const float rescale = (prev_max == -INFINITY)
                ? 0.0f
                : fast::exp(prev_max - next_max);
            running[0] = next_max;
            running[1] *= rescale;
            reduce_buf[0] = rescale;
            reduce_buf[1] = next_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        const float alpha = reduce_buf[0];
        const float m_new = reduce_buf[1];

        // [4] Bring the old output onto the new max, and turn the tile's
        //     scores into softmax numerators exp(S - m_new).
        for (uint d = tid; d < head_dim; d += 256) {
            o_acc[d] *= alpha;
        }
        for (uint s = tid; s < tile_len; s += 256) {
            tile_p[s] = fast::exp(tile_p[s] - m_new);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Sum of the numerators, added to the running denominator.
        float lane_sum = 0.0f;
        for (uint s = tid; s < tile_len; s += 256) {
            lane_sum += tile_p[s];
        }
        lane_sum = simd_sum(lane_sum);
        if (simd_lane == 0) {
            reduce_buf[simd_id] = lane_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += reduce_buf[i];
            running[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] O[d] += Σ_s P[s] · V[tile_start + s, d]; one thread per output dim.
        for (uint d = tid; d < head_dim; d += 256) {
            float weighted = 0.0f;
            for (uint s = 0; s < tile_len; s++) {
                weighted += tile_p[s] * v_cache[(tile_start + s) * kv_row_stride + kv_head_off + d];
            }
            o_acc[d] += weighted;
        }
        // tile_p and reduce_buf are rewritten by the next tile.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Final normalization O / l; a zero denominator yields zeros.
    const float denom = running[1];
    const float inv_l = (denom > 0.0f) ? (1.0f / denom) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        output_batch[out_off + d] = o_acc[d] * inv_l;
    }
}
2634
2635// ── attention_mma_flash_batch ─────────────────────────────────────────
2636// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
2637// both Q·K^T and P·V.  Processes Q_BLOCK=8 tokens of one head per
2638// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
2639// K/V loads across 8 Q rows and using hardware matrix-multiply for the
2640// arithmetic.
2641//
2642// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
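// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
// to attention_batch / attention_flash_batch otherwise.
// The tiling below additionally assumes head_dim is a multiple of 32:
// Q·K^T walks head_dim in 8-wide chunks, and P·V gives each of the 4
// simdgroups head_dim/4 output columns split into 8-wide MMA tiles, so
// any remainder from those integer divisions would be left unprocessed.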
2645//
2646// Online softmax recurrence is identical to attention_flash_batch but
2647// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
2648constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
2649constant constexpr uint FLASH_MMA_K_BLOCK = 32;
2650constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;
2651
2652kernel void attention_mma_flash_batch(
2653    device const float* q_batch      [[buffer(0)]],
2654    device const half*  k_cache      [[buffer(1)]],
2655    device const half*  v_cache      [[buffer(2)]],
2656    device float* output_batch       [[buffer(3)]],
2657    constant uint& M                 [[buffer(4)]],
2658    constant uint& base_pos          [[buffer(5)]],
2659    constant uint& num_heads         [[buffer(6)]],
2660    constant uint& num_kv_heads      [[buffer(7)]],
2661    constant uint& head_dim          [[buffer(8)]],
2662    constant uint& q_stride          [[buffer(9)]],
2663    uint2 tgid [[threadgroup_position_in_grid]],
2664    uint tid [[thread_index_in_threadgroup]],
2665    uint simd_lane [[thread_index_in_simdgroup]],
2666    uint simd_id [[simdgroup_index_in_threadgroup]])
2667{
2668    uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2669    uint head = tgid.y;
2670    if (q_block_start >= M) return;
2671    uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2672
2673    uint kv_head = head / (num_heads / num_kv_heads);
2674    // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2675    // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2676    uint seq_len = base_pos + q_block_start + q_valid;
2677    float scale = fast::rsqrt(float(head_dim));
2678
2679    uint kv_stride = num_kv_heads * head_dim;
2680    uint kv_base_off = kv_head * head_dim;
2681
2682    // ── Threadgroup memory ──
2683    // q_sh:  [Q_BLOCK, head_dim] half  — Q tile, loaded once
2684    // k_sh:  [K_BLOCK, head_dim] half  — K tile, refreshed per kv_base iter
2685    // v_sh:  [K_BLOCK, head_dim] half  — V tile, refreshed per kv_base iter
2686    // s_sh:  [Q_BLOCK, K_BLOCK] float  — raw Q·K^T scores, then scaled+masked
2687    // p_sh:  [Q_BLOCK, K_BLOCK] half   — softmax probabilities (for P·V MMA)
2688    // o_sh:  [Q_BLOCK, head_dim] float — running output accumulator
2689    // m_sh:  [Q_BLOCK] float           — running max per Q row
2690    // l_sh:  [Q_BLOCK] float           — running softmax denominator per Q row
2691    // scratch: 4*Q_BLOCK floats        — per-row reduction scratch
2692    threadgroup half  q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2693    threadgroup half  k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2694    threadgroup half  v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2695    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2696    threadgroup half  p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2697    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2698    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2699    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2700    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2701
2702    // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2703    uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2704    for (uint i = tid; i < qblock_elems; i += 128) {
2705        uint q = i / head_dim;
2706        uint d = i % head_dim;
2707        if (q < q_valid) {
2708            uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2709            q_sh[q * head_dim + d] = half(q_batch[q_off]);
2710        } else {
2711            q_sh[q * head_dim + d] = half(0);
2712        }
2713        o_sh[q * head_dim + d] = 0.0f;
2714    }
2715    if (tid < FLASH_MMA_Q_BLOCK) {
2716        m_sh[tid] = -INFINITY;
2717        l_sh[tid] = 0.0f;
2718    }
2719    threadgroup_barrier(mem_flags::mem_threadgroup);
2720
2721    // ── Stream K/V in K_BLOCK chunks ──
2722    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2723        uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2724
2725        // Load K and V tile into TG memory (as half).
2726        uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2727        for (uint i = tid; i < kv_tile_elems; i += 128) {
2728            uint k_pos = i / head_dim;
2729            uint d = i % head_dim;
2730            if (k_pos < tile_n) {
2731                uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2732                k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2733                v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2734            } else {
2735                k_sh[k_pos * head_dim + d] = half(0);
2736                v_sh[k_pos * head_dim + d] = half(0);
2737            }
2738        }
2739        threadgroup_barrier(mem_flags::mem_threadgroup);
2740
2741        // ── Phase 1: S = Q @ K^T via MMA ──
2742        // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2743        // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2744        // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2745        {
2746            simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2747            uint dim_chunks = head_dim / 8;
2748            for (uint dc = 0; dc < dim_chunks; dc++) {
2749                simdgroup_matrix<half, 8, 8> A, B;
2750                // A = Q[0:8, dc*8 : dc*8+8]  (rows of Q, no transpose)
2751                simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2752                // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2753                // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2754                // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2755                // transpose=true, which places it in the register as K^T of that sub-block.
2756                simdgroup_load(B,
2757                    k_sh + (simd_id * 8) * head_dim + dc * 8,
2758                    head_dim, ulong2(0, 0), true);
2759                simdgroup_multiply_accumulate(C, A, B, C);
2760            }
2761            // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2762            simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2763        }
2764        threadgroup_barrier(mem_flags::mem_threadgroup);
2765
2766        // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2767        // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2768        uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2769        for (uint i = tid; i < s_elems; i += 128) {
2770            uint q = i / FLASH_MMA_K_BLOCK;
2771            uint k = i % FLASH_MMA_K_BLOCK;
2772            uint global_q = q_block_start + q;
2773            uint global_kv = kv_base + k;
2774            bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2775            s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2776        }
2777        threadgroup_barrier(mem_flags::mem_threadgroup);
2778
2779        // ── Phase 2b: per-row max via simdgroup reduction ──
2780        // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2781        // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2782        {
2783            uint row_base = simd_id * 2;
2784            for (uint qr = 0; qr < 2; qr++) {
2785                uint q = row_base + qr;
2786                float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2787                float row_max = simd_max(my);
2788                if (simd_lane == 0) {
2789                    scratch[q] = row_max;  // tile_max[q]
2790                }
2791            }
2792        }
2793        threadgroup_barrier(mem_flags::mem_threadgroup);
2794
2795        // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2796        if (tid < FLASH_MMA_Q_BLOCK) {
2797            uint q = tid;
2798            float m_old = m_sh[q];
2799            float tile_max = scratch[q];
2800            float m_new = max(m_old, tile_max);
2801            float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2802            m_sh[q] = m_new;
2803            l_sh[q] = l_sh[q] * alpha;
2804            // scratch[q]               = m_new   (for phase 2d)
2805            // scratch[Q_BLOCK + q]     = alpha   (for phase 3)
2806            scratch[q] = m_new;
2807            scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2808        }
2809        threadgroup_barrier(mem_flags::mem_threadgroup);
2810
2811        // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2812        {
2813            uint row_base = simd_id * 2;
2814            for (uint qr = 0; qr < 2; qr++) {
2815                uint q = row_base + qr;
2816                float m_new = scratch[q];
2817                float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2818                p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2819                float row_sum = simd_sum(p);
2820                if (simd_lane == 0) {
2821                    scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2822                }
2823            }
2824        }
2825        threadgroup_barrier(mem_flags::mem_threadgroup);
2826
2827        // ── Phase 2e: l_sh += tile_sum ──
2828        if (tid < FLASH_MMA_Q_BLOCK) {
2829            uint q = tid;
2830            l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2831        }
2832
2833        // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2834        threadgroup_barrier(mem_flags::mem_threadgroup);
2835        for (uint i = tid; i < qblock_elems; i += 128) {
2836            uint q = i / head_dim;
2837            float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2838            o_sh[i] *= alpha;
2839        }
2840        threadgroup_barrier(mem_flags::mem_threadgroup);
2841
2842        // ── Phase 4: O += P @ V via MMA ──
2843        // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2844        // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2845        // For head_dim=64: 16 dims/sg = 2 tiles.  head_dim=128: 32 dims/sg = 4 tiles.
2846        {
2847            uint dims_per_sg = head_dim / 4;       // 16 or 32
2848            uint tiles_per_sg = dims_per_sg / 8;   // 2 or 4
2849            uint sg_d_base = simd_id * dims_per_sg;
2850            for (uint t = 0; t < tiles_per_sg; t++) {
2851                uint d_base = sg_d_base + t * 8;
2852                simdgroup_matrix<float, 8, 8> O_acc;
2853                simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2854                uint k_chunks = FLASH_MMA_K_BLOCK / 8;  // 4
2855                for (uint kc = 0; kc < k_chunks; kc++) {
2856                    simdgroup_matrix<half, 8, 8> A, B;
2857                    simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2858                                   ulong2(0, 0), false);
2859                    simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2860                                   ulong2(0, 0), false);
2861                    simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2862                }
2863                simdgroup_store(O_acc, o_sh + d_base, head_dim);
2864            }
2865        }
2866        threadgroup_barrier(mem_flags::mem_threadgroup);
2867    }
2868
2869    // ── Finalize: O /= l, write to output ──
2870    for (uint i = tid; i < qblock_elems; i += 128) {
2871        uint q = i / head_dim;
2872        uint d = i % head_dim;
2873        if (q < q_valid) {
2874            float l = l_sh[q];
2875            float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2876            uint token_idx = q_block_start + q;
2877            uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2878            output_batch[out_off] = o_sh[i] * inv_l;
2879        }
2880    }
2881}
2882
2883// ── rope_qk_batch ─────────────────────────────────────────────────────
2884// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2885// launch + memory barrier per layer. Both Q and K live in the same
2886// qkv_data buffer at different offsets within each token's row.
2887// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2888kernel void rope_qk_batch(
2889    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
2890    constant uint& M                 [[buffer(1)]],   // num tokens
2891    constant uint& base_pos          [[buffer(2)]],   // starting position
2892    constant uint& num_q_heads       [[buffer(3)]],
2893    constant uint& num_kv_heads      [[buffer(4)]],
2894    constant uint& head_dim          [[buffer(5)]],
2895    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
2896    constant float& theta            [[buffer(7)]],
2897    uint id [[thread_position_in_grid]])
2898{
2899    uint half_dim = head_dim / 2;
2900    uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2901    uint token = id / total_pairs;
2902    uint pair = id % total_pairs;
2903    if (token >= M) return;
2904
2905    uint pos = base_pos + token;
2906    uint q_pairs = num_q_heads * half_dim;
2907
2908    uint h, i, offset;
2909    if (pair < q_pairs) {
2910        // Q head
2911        h = pair / half_dim;
2912        i = pair % half_dim;
2913        offset = token * qkv_stride + h * head_dim + i * 2;
2914    } else {
2915        // K head
2916        uint kp = pair - q_pairs;
2917        h = kp / half_dim;
2918        i = kp % half_dim;
2919        uint k_start = num_q_heads * head_dim;
2920        offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2921    }
2922
2923    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2924    float angle = float(pos) * freq;
2925    float cos_val = cos(angle);
2926    float sin_val = sin(angle);
2927
2928    float x0 = qkv_data[offset];
2929    float x1 = qkv_data[offset + 1];
2930    qkv_data[offset]     = x0 * cos_val - x1 * sin_val;
2931    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2932}
2933
2934// ── copy_kv_both_batch ────────────────────────────────────────────────
2935// Fused K+V cache copy in a single dispatch: copies both K and V from
2936// the strided batch QKV buffer to their respective KV cache buffers.
2937// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2938kernel void copy_kv_both_batch(
2939    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride] f32
2940    device half* k_dst         [[buffer(1)]],  // K cache [max_seq, kv_dim] f16
2941    device half* v_dst         [[buffer(2)]],  // V cache [max_seq, kv_dim] f16
2942    constant uint& M           [[buffer(3)]],  // num tokens in batch
2943    constant uint& kv_dim      [[buffer(4)]],  // elements per KV vector
2944    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
2945    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
2946    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
2947    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
2948    uint id [[thread_position_in_grid]])
2949{
2950    // Total elements = M * kv_dim * 2 (K + V)
2951    uint total_kv = M * kv_dim;
2952    if (id >= total_kv * 2) return;
2953
2954    uint is_v = id / total_kv;        // 0 = K, 1 = V
2955    uint local_id = id % total_kv;
2956    uint token = local_id / kv_dim;
2957    uint d = local_id % kv_dim;
2958
2959    uint dst_off = (base_pos + token) * kv_dim + d;
2960    uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2961
2962    if (is_v) {
2963        v_dst[dst_off] = half(src[src_off]);
2964    } else {
2965        k_dst[dst_off] = half(src[src_off]);
2966    }
2967}
2968"#
2969    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2970    .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2971}
2972
2973// ---------------------------------------------------------------------------
2974// model.rs generation
2975// ---------------------------------------------------------------------------
2976
2977fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2978    let mut code = String::with_capacity(48 * 1024);
2979    emit_model_header(&mut code, config)?;
2980    emit_metal_model_struct(&mut code, config)?;
2981    emit_layer_buffers_struct(&mut code, config)?;
2982    emit_metal_model_impl(&mut code, config)?;
2983    emit_helper_functions(&mut code)?;
2984    Ok(code)
2985}
2986
2987fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2988    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2989    writeln!(
2990        code,
2991        "//! Model: {} ({} layers, hidden={})",
2992        config.architecture, config.num_layers, config.hidden_size
2993    )?;
2994    writeln!(code, "//!")?;
2995    writeln!(
2996        code,
2997        "//! Uses native Metal compute pipelines via the metal crate."
2998    )?;
2999    writeln!(code)?;
3000    writeln!(code, "#![allow(dead_code)]")?;
3001    writeln!(code)?;
3002    writeln!(code, "use metal::*;")?;
3003    writeln!(code, "#[allow(unused_imports)]")?;
3004    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
3005    writeln!(code, "use std::mem;")?;
3006    writeln!(code)?;
3007
3008    // Model constants
3009    writeln!(
3010        code,
3011        "// ── Model constants ──────────────────────────────────"
3012    )?;
3013    writeln!(
3014        code,
3015        "pub const HIDDEN_SIZE: usize = {};",
3016        config.hidden_size
3017    )?;
3018    writeln!(
3019        code,
3020        "pub const INTERMEDIATE_SIZE: usize = {};",
3021        config.intermediate_size
3022    )?;
3023    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
3024    writeln!(
3025        code,
3026        "pub const NUM_HEADS: usize = {};",
3027        config.num_attention_heads
3028    )?;
3029    writeln!(
3030        code,
3031        "pub const NUM_KV_HEADS: usize = {};",
3032        config.num_kv_heads
3033    )?;
3034    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3035    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3036    let effective_seq_len = config.max_seq_len.min(4096);
3037    writeln!(
3038        code,
3039        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
3040        effective_seq_len, config.max_seq_len
3041    )?;
3042    writeln!(
3043        code,
3044        "pub const RMS_NORM_EPS: f32 = {:e};",
3045        config.rms_norm_eps
3046    )?;
3047    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3048    writeln!(
3049        code,
3050        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3051    )?;
3052    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3053    writeln!(code)?;
3054
3055    Ok(())
3056}
3057
3058fn emit_metal_model_struct(
3059    code: &mut String,
3060    config: &ModelConfig,
3061) -> Result<(), MetalCodegenError> {
3062    writeln!(
3063        code,
3064        "// ── MetalModel ──────────────────────────────────────────"
3065    )?;
3066    writeln!(code)?;
3067    writeln!(
3068        code,
3069        "/// Metal-accelerated transformer model for Apple Silicon."
3070    )?;
3071    writeln!(code, "///")?;
3072    writeln!(
3073        code,
3074        "/// Uses unified memory for zero-copy weight access and native Metal"
3075    )?;
3076    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3077    writeln!(code, "pub struct MetalModel {{")?;
3078    writeln!(code, "    device: Device,")?;
3079    writeln!(code, "    queue: CommandQueue,")?;
3080    writeln!(code)?;
3081    writeln!(code, "    // ── Compute pipelines ──")?;
3082    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
3083    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
3084    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
3085    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
3086    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
3087    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
3088    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
3089    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
3090    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
3091    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
3092    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
3093    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
3094    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
3095    writeln!(
3096        code,
3097        "    copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3098    )?;
3099    writeln!(code)?;
3100    writeln!(code, "    // ── Batched prefill pipelines ──")?;
3101    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
3102    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
3103    writeln!(
3104        code,
3105        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3106    )?;
3107    writeln!(code, "    matmul_q8_mma_pipeline: ComputePipelineState,")?;
3108    writeln!(code, "    matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3109    writeln!(
3110        code,
3111        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3112    )?;
3113    writeln!(
3114        code,
3115        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3116    )?;
3117    writeln!(
3118        code,
3119        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3120    )?;
3121    if config.qkv_bias {
3122        writeln!(code, "    add_bias_batch_pipeline: ComputePipelineState,")?;
3123    }
3124    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
3125    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
3126    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
3127    writeln!(
3128        code,
3129        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
3130    )?;
3131    writeln!(
3132        code,
3133        "    add_inplace_batch_pipeline: ComputePipelineState,"
3134    )?;
3135    writeln!(
3136        code,
3137        "    copy_embedding_batch_pipeline: ComputePipelineState,"
3138    )?;
3139    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
3140    writeln!(
3141        code,
3142        "    attention_flash_batch_pipeline: ComputePipelineState,"
3143    )?;
3144    writeln!(
3145        code,
3146        "    attention_mma_flash_batch_pipeline: ComputePipelineState,"
3147    )?;
3148    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
3149    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
3150    writeln!(
3151        code,
3152        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
3153    )?;
3154    writeln!(code)?;
3155    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
3156    writeln!(
3157        code,
3158        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3159    )?;
3160    writeln!(code, "    embed_buf: Buffer,")?;
3161    writeln!(code)?;
3162    writeln!(code, "    /// Per-layer weight buffers")?;
3163    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
3164    writeln!(code)?;
3165    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
3166    writeln!(code, "    norm_buf: Buffer,")?;
3167    writeln!(code)?;
3168    writeln!(
3169        code,
3170        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3171    )?;
3172    writeln!(code, "    lm_head_buf: Buffer,")?;
3173    writeln!(code)?;
3174    writeln!(
3175        code,
3176        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
3177    )?;
3178    writeln!(code, "    hidden_buf: Buffer,")?;
3179    writeln!(code, "    residual_buf: Buffer,")?;
3180    writeln!(code, "    normed_buf: Buffer,")?;
3181    writeln!(
3182        code,
3183        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3184    )?;
3185    writeln!(code, "    qkv_buf: Buffer,")?;
3186    writeln!(code, "    attn_out_buf: Buffer,")?;
3187    writeln!(code, "    attn_proj_buf: Buffer,")?;
3188    writeln!(
3189        code,
3190        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3191    )?;
3192    writeln!(code, "    gate_up_buf: Buffer,")?;
3193    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
3194    writeln!(code, "    ffn_out_buf: Buffer,")?;
3195    writeln!(code, "    add_tmp_buf: Buffer,")?;
3196    writeln!(code, "    logits_buf: Buffer,")?;
3197    writeln!(code)?;
3198    writeln!(code, "    // ── Batched prefill working buffers ──")?;
3199    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3200    writeln!(code, "    batch_hidden_buf: Buffer,")?;
3201    writeln!(
3202        code,
3203        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3204    )?;
3205    writeln!(code, "    batch_residual_buf: Buffer,")?;
3206    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3207    writeln!(code, "    batch_qkv_buf: Buffer,")?;
3208    writeln!(
3209        code,
3210        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3211    )?;
3212    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
3213    writeln!(
3214        code,
3215        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3216    )?;
3217    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
3218    writeln!(
3219        code,
3220        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3221    )?;
3222    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
3223    writeln!(
3224        code,
3225        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3226    )?;
3227    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
3228    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3229    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
3230    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
3231    writeln!(code, "    batch_tokens_buf: Buffer,")?;
3232    writeln!(code, "    /// Positions buffer for batch RoPE")?;
3233    writeln!(code, "    batch_positions_buf: Buffer,")?;
3234    writeln!(code)?;
3235    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
3236    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
3237    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
3238    writeln!(code)?;
3239    writeln!(code, "    // ── Inference state ──")?;
3240    writeln!(code, "    pos: usize,")?;
3241    writeln!(code)?;
3242    writeln!(
3243        code,
3244        "    /// Previous command buffer for double-buffered prefill."
3245    )?;
3246    writeln!(
3247        code,
3248        "    /// While the GPU executes token N, the CPU can encode token N+1."
3249    )?;
3250    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
3251    writeln!(code, "}}")?;
3252    writeln!(code)?;
3253
3254    Ok(())
3255}
3256
3257fn emit_layer_buffers_struct(
3258    code: &mut String,
3259    config: &ModelConfig,
3260) -> Result<(), MetalCodegenError> {
3261    writeln!(
3262        code,
3263        "/// Per-layer weight buffers for attention and FFN projections."
3264    )?;
3265    writeln!(code, "struct LayerBuffers {{")?;
3266    writeln!(code, "    attn_norm: Buffer,")?;
3267    writeln!(
3268        code,
3269        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3270    )?;
3271    writeln!(code, "    qkv_weight: Buffer,")?;
3272    if config.qkv_bias {
3273        writeln!(
3274            code,
3275            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3276        )?;
3277        writeln!(code, "    qkv_bias: Buffer,")?;
3278    }
3279    writeln!(code, "    o_weight: Buffer,")?;
3280    writeln!(code, "    ffn_norm: Buffer,")?;
3281    writeln!(
3282        code,
3283        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3284    )?;
3285    writeln!(code, "    gate_up_weight: Buffer,")?;
3286    writeln!(code, "    down_weight: Buffer,")?;
3287    writeln!(code, "}}")?;
3288    writeln!(code)?;
3289
3290    Ok(())
3291}
3292
3293fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3294    let hidden = config.hidden_size;
3295    let intermediate = config.intermediate_size;
3296    let _num_layers = config.num_layers;
3297    let _num_heads = config.num_attention_heads;
3298    let num_kv_heads = config.num_kv_heads;
3299    let head_dim = config.head_dim;
3300    let vocab = config.vocab_size;
3301    let effective_seq_len = config.max_seq_len.min(4096);
3302    let is_q8 = config.dtype == DType::Q8_0;
3303    let is_q4 = config.dtype == DType::Q4_0;
3304    let kv_dim = num_kv_heads * head_dim;
3305
3306    writeln!(code, "impl MetalModel {{")?;
3307
3308    // ── new() ──
3309    writeln!(
3310        code,
3311        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3312    )?;
3313    writeln!(code, "    ///")?;
3314    writeln!(
3315        code,
3316        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3317    )?;
3318    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3319    writeln!(
3320        code,
3321        "        let device = Device::system_default().expect(\"no Metal device found\");"
3322    )?;
3323    writeln!(code, "        let queue = device.new_command_queue();")?;
3324    writeln!(code)?;
3325
3326    // Compile shaders
3327    writeln!(
3328        code,
3329        "        // Compile Metal shaders from embedded source"
3330    )?;
3331    writeln!(
3332        code,
3333        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3334    )?;
3335    writeln!(
3336        code,
3337        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3338    )?;
3339    writeln!(
3340        code,
3341        "            .expect(\"failed to compile Metal shaders\");"
3342    )?;
3343    writeln!(code)?;
3344
3345    // Create compute pipelines
3346    writeln!(code, "        // Create compute pipelines")?;
3347    for (var, fn_name) in [
3348        ("matmul_pipeline", "matmul_vec"),
3349        ("matmul_q8_pipeline", "matmul_vec_q8"),
3350        ("matmul_q4_pipeline", "matmul_vec_q4"),
3351        ("rms_norm_pipeline", "rms_norm"),
3352        ("rope_pipeline", "rope"),
3353        ("softmax_pipeline", "softmax"),
3354        ("silu_mul_pipeline", "silu_mul"),
3355        ("silu_mul_fused_pipeline", "silu_mul_fused"),
3356        ("add_pipeline", "elementwise_add"),
3357        ("attention_pipeline", "attention"),
3358        ("add_inplace_pipeline", "add_inplace"),
3359        ("copy_pipeline", "copy_buffer"),
3360        ("copy_offset_pipeline", "copy_offset"),
3361        ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3362        ("matmul_batch_pipeline", "matmul_vec_batch"),
3363        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3364        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3365        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3366        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3367        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3368        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3369        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3370        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3371        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3372        ("rope_batch_pipeline", "rope_batch"),
3373        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3374        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3375        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3376        ("attention_batch_pipeline", "attention_batch"),
3377        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3378        (
3379            "attention_mma_flash_batch_pipeline",
3380            "attention_mma_flash_batch",
3381        ),
3382        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3383        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3384        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3385    ] {
3386        writeln!(
3387            code,
3388            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3389        )?;
3390    }
3391    if config.qkv_bias {
3392        writeln!(
3393            code,
3394            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3395        )?;
3396    }
3397    writeln!(code)?;
3398
3399    // Weight loading
3400    writeln!(
3401        code,
3402        "        // Load weights into Metal shared-memory buffers"
3403    )?;
3404    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3405    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3406    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3407    writeln!(code)?;
3408    writeln!(
3409        code,
3410        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3411    )?;
3412    writeln!(code)?;
3413    writeln!(
3414        code,
3415        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3416    )?;
3417    writeln!(
3418        code,
3419        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3420    )?;
3421    writeln!(code, "            let byte_len = n * f32_size;")?;
3422    writeln!(code, "            let cur = cursor.get();")?;
3423    writeln!(
3424        code,
3425        "            let data = &weights[cur..cur + byte_len];"
3426    )?;
3427    writeln!(code, "            cursor.set(cur + byte_len);")?;
3428    writeln!(code, "            device.new_buffer_with_data(")?;
3429    writeln!(code, "                data.as_ptr() as *const _,")?;
3430    writeln!(code, "                byte_len as u64,")?;
3431    writeln!(
3432        code,
3433        "                MTLResourceOptions::StorageModeShared,"
3434    )?;
3435    writeln!(code, "            )")?;
3436    writeln!(code, "        }};")?;
3437    writeln!(code)?;
3438
3439    if is_q8 {
3440        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3441        // We load them directly into Metal buffers without dequantizing,
3442        // and use the matmul_vec_q8 shader that operates on quantized data.
3443        // This cuts GPU memory usage and memory bandwidth to roughly a quarter of the f32-dequantized size (34 vs 128 bytes per 32 values).
3444        writeln!(
3445            code,
3446            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3447        )?;
3448        writeln!(
3449            code,
3450            "        // as raw bytes into a Metal buffer (no dequantization)."
3451        )?;
3452        writeln!(
3453            code,
3454            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3455        )?;
3456        writeln!(
3457            code,
3458            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3459        )?;
3460        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3461        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3462        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3463        writeln!(code, "            let cur = cursor.get();")?;
3464        writeln!(
3465            code,
3466            "            let data = &weights[cur..cur + total_raw];"
3467        )?;
3468        writeln!(code, "            cursor.set(cur + total_raw);")?;
3469        writeln!(code, "            device.new_buffer_with_data(")?;
3470        writeln!(code, "                data.as_ptr() as *const _,")?;
3471        writeln!(code, "                total_raw as u64,")?;
3472        writeln!(
3473            code,
3474            "                MTLResourceOptions::StorageModeShared,"
3475        )?;
3476        writeln!(code, "            )")?;
3477        writeln!(code, "        }};")?;
3478        writeln!(code)?;
3479        writeln!(
3480            code,
3481            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3482        )?;
3483        writeln!(
3484            code,
3485            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3486        )?;
3487        writeln!(
3488            code,
3489            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3490        )?;
3491        writeln!(
3492            code,
3493            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3494        )?;
3495        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3496        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3497        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3498        writeln!(code, "            let cur = cursor.get();")?;
3499        writeln!(
3500            code,
3501            "            let data = &weights[cur..cur + total_raw];"
3502        )?;
3503        writeln!(code, "            cursor.set(cur + total_raw);")?;
3504        writeln!(code, "            device.new_buffer_with_data(")?;
3505        writeln!(code, "                data.as_ptr() as *const _,")?;
3506        writeln!(code, "                total_raw as u64,")?;
3507        writeln!(
3508            code,
3509            "                MTLResourceOptions::StorageModeShared,"
3510        )?;
3511        writeln!(code, "            )")?;
3512        writeln!(code, "        }};")?;
3513        writeln!(code)?;
3514    }
3515
3516    if is_q4 {
3517        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3518        // We load them directly into Metal buffers without dequantizing,
3519        // and use the matmul_vec_q4 shader that operates on quantized data.
3520        // This cuts GPU memory usage to roughly one seventh of the f32-dequantized size (18 vs 128 bytes per 32 values).
3521        writeln!(
3522            code,
3523            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3524        )?;
3525        writeln!(
3526            code,
3527            "        // as raw bytes into a Metal buffer (no dequantization)."
3528        )?;
3529        writeln!(
3530            code,
3531            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3532        )?;
3533        writeln!(
3534            code,
3535            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3536        )?;
3537        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3538        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3539        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3540        writeln!(code, "            let cur = cursor.get();")?;
3541        writeln!(
3542            code,
3543            "            let data = &weights[cur..cur + total_raw];"
3544        )?;
3545        writeln!(code, "            cursor.set(cur + total_raw);")?;
3546        writeln!(code, "            device.new_buffer_with_data(")?;
3547        writeln!(code, "                data.as_ptr() as *const _,")?;
3548        writeln!(code, "                total_raw as u64,")?;
3549        writeln!(
3550            code,
3551            "                MTLResourceOptions::StorageModeShared,"
3552        )?;
3553        writeln!(code, "            )")?;
3554        writeln!(code, "        }};")?;
3555        writeln!(code)?;
3556        writeln!(
3557            code,
3558            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3559        )?;
3560        writeln!(
3561            code,
3562            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3563        )?;
3564        writeln!(
3565            code,
3566            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3567        )?;
3568        writeln!(
3569            code,
3570            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3571        )?;
3572        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3573        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3574        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3575        writeln!(code, "            let cur = cursor.get();")?;
3576        writeln!(
3577            code,
3578            "            let data = &weights[cur..cur + total_raw];"
3579        )?;
3580        writeln!(code, "            cursor.set(cur + total_raw);")?;
3581        writeln!(code, "            device.new_buffer_with_data(")?;
3582        writeln!(code, "                data.as_ptr() as *const _,")?;
3583        writeln!(code, "                total_raw as u64,")?;
3584        writeln!(
3585            code,
3586            "                MTLResourceOptions::StorageModeShared,"
3587        )?;
3588        writeln!(code, "            )")?;
3589        writeln!(code, "        }};")?;
3590        writeln!(code)?;
3591    }
3592
3593    writeln!(
3594        code,
3595        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3596    )?;
3597    writeln!(code)?;
3598
3599    // Per-layer weights
3600    writeln!(
3601        code,
3602        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3603    )?;
3604    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3605
3606    // attn_norm is always f32
3607    writeln!(
3608        code,
3609        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3610    )?;
3611
3612    let qkv_rows = hidden + 2 * kv_dim;
3613    if is_q8 {
3614        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3615        writeln!(
3616            code,
3617            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3618        )?;
3619        if config.qkv_bias {
3620            writeln!(
3621                code,
3622                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3623            )?;
3624            writeln!(
3625                code,
3626                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3627            )?;
3628        }
3629        writeln!(
3630            code,
3631            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3632        )?;
3633    } else if is_q4 {
3634        writeln!(
3635            code,
3636            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3637        )?;
3638        if config.qkv_bias {
3639            writeln!(
3640                code,
3641                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3642            )?;
3643        }
3644        writeln!(
3645            code,
3646            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3647        )?;
3648    } else {
3649        writeln!(
3650            code,
3651            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3652        )?;
3653        if config.qkv_bias {
3654            writeln!(
3655                code,
3656                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3657            )?;
3658        }
3659        writeln!(
3660            code,
3661            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3662        )?;
3663    }
3664
3665    // ffn_norm is always f32
3666    writeln!(
3667        code,
3668        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3669    )?;
3670
3671    let gate_up_rows = 2 * intermediate;
3672    if is_q8 {
3673        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3674        writeln!(
3675            code,
3676            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3677        )?;
3678        writeln!(
3679            code,
3680            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3681        )?;
3682    } else if is_q4 {
3683        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3684        writeln!(
3685            code,
3686            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3687        )?;
3688        writeln!(
3689            code,
3690            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3691        )?;
3692    } else {
3693        // Fused gate+up weight: read both as a single contiguous f32 buffer
3694        writeln!(
3695            code,
3696            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3697        )?;
3698        writeln!(
3699            code,
3700            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3701        )?;
3702    }
3703
3704    writeln!(code, "            layers.push(LayerBuffers {{")?;
3705    writeln!(code, "                attn_norm,")?;
3706    writeln!(code, "                qkv_weight,")?;
3707    if config.qkv_bias {
3708        writeln!(code, "                qkv_bias,")?;
3709    }
3710    writeln!(code, "                o_weight,")?;
3711    writeln!(code, "                ffn_norm,")?;
3712    writeln!(code, "                gate_up_weight,")?;
3713    writeln!(code, "                down_weight,")?;
3714    writeln!(code, "            }});")?;
3715    writeln!(code, "        }}")?;
3716    writeln!(code)?;
3717
3718    // final_norm is always f32
3719    writeln!(
3720        code,
3721        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3722    )?;
3723    writeln!(code)?;
3724
3725    // lm_head
3726    if is_q8 {
3727        writeln!(
3728            code,
3729            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3730        )?;
3731    } else if is_q4 {
3732        writeln!(
3733            code,
3734            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3735        )?;
3736    } else {
3737        writeln!(
3738            code,
3739            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3740        )?;
3741    }
3742    writeln!(code)?;
3743
3744    // Working buffers
3745    let hidden_bytes = hidden * 4;
3746    let _kv_dim_bytes = kv_dim * 4;
3747    let intermediate_bytes = intermediate * 4;
3748    let vocab_bytes = vocab * 4;
3749    // KV cache is stored as f16 (2 bytes/element) to halve attention memory
3750    // bandwidth in long-context decode.
3751    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3752
3753    writeln!(
3754        code,
3755        "        // Allocate working buffers (shared memory for zero-copy)"
3756    )?;
3757    writeln!(
3758        code,
3759        "        let opts = MTLResourceOptions::StorageModeShared;"
3760    )?;
3761    writeln!(
3762        code,
3763        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3764    )?;
3765    writeln!(
3766        code,
3767        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3768    )?;
3769    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3770    writeln!(
3771        code,
3772        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3773    )?;
3774    writeln!(
3775        code,
3776        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3777    )?;
3778    writeln!(
3779        code,
3780        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3781    )?;
3782    writeln!(
3783        code,
3784        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3785    )?;
3786    writeln!(
3787        code,
3788        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3789    )?;
3790    let gate_up_buf_bytes = 2 * intermediate * 4;
3791    writeln!(
3792        code,
3793        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3794    )?;
3795    writeln!(
3796        code,
3797        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3798    )?;
3799    writeln!(
3800        code,
3801        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3802    )?;
3803    writeln!(
3804        code,
3805        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3806    )?;
3807    writeln!(
3808        code,
3809        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3810    )?;
3811    writeln!(
3812        code,
3813        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3814    )?;
3815    writeln!(code)?;
3816
3817    // Batch prefill working buffers
3818    let batch_hidden_bytes = hidden * 4; // per-token
3819    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3820    let batch_gate_up_bytes = 2 * intermediate * 4;
3821    let batch_intermediate_bytes = intermediate * 4;
3822    writeln!(
3823        code,
3824        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3825    )?;
3826    writeln!(
3827        code,
3828        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3829    )?;
3830    writeln!(
3831        code,
3832        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3833    )?;
3834    writeln!(
3835        code,
3836        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3837    )?;
3838    writeln!(
3839        code,
3840        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3841    )?;
3842    writeln!(
3843        code,
3844        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3845    )?;
3846    writeln!(
3847        code,
3848        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3849    )?;
3850    writeln!(
3851        code,
3852        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3853    )?;
3854    writeln!(
3855        code,
3856        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3857    )?;
3858    writeln!(
3859        code,
3860        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3861    )?;
3862    writeln!(
3863        code,
3864        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3865    )?;
3866    writeln!(code)?;
3867
3868    // KV cache buffers
3869    writeln!(code, "        // KV cache buffers (per-layer)")?;
3870    writeln!(
3871        code,
3872        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3873    )?;
3874    writeln!(
3875        code,
3876        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3877    )?;
3878    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3879    writeln!(
3880        code,
3881        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3882    )?;
3883    writeln!(
3884        code,
3885        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3886    )?;
3887    writeln!(code, "        }}")?;
3888    writeln!(code)?;
3889
3890    writeln!(code, "        Self {{")?;
3891    writeln!(code, "            device,")?;
3892    writeln!(code, "            queue,")?;
3893    writeln!(code, "            matmul_pipeline,")?;
3894    writeln!(code, "            matmul_q8_pipeline,")?;
3895    writeln!(code, "            matmul_q4_pipeline,")?;
3896    writeln!(code, "            rms_norm_pipeline,")?;
3897    writeln!(code, "            rope_pipeline,")?;
3898    writeln!(code, "            softmax_pipeline,")?;
3899    writeln!(code, "            silu_mul_pipeline,")?;
3900    writeln!(code, "            silu_mul_fused_pipeline,")?;
3901    writeln!(code, "            add_pipeline,")?;
3902    writeln!(code, "            attention_pipeline,")?;
3903    writeln!(code, "            add_inplace_pipeline,")?;
3904    writeln!(code, "            copy_pipeline,")?;
3905    writeln!(code, "            copy_offset_pipeline,")?;
3906    writeln!(code, "            copy_f32_to_f16_offset_pipeline,")?;
3907    writeln!(code, "            matmul_batch_pipeline,")?;
3908    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3909    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3910    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3911    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3912    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3913    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3914    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3915    if config.qkv_bias {
3916        writeln!(code, "            add_bias_batch_pipeline,")?;
3917    }
3918    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3919    writeln!(code, "            rms_norm_batch_pipeline,")?;
3920    writeln!(code, "            rope_batch_pipeline,")?;
3921    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3922    writeln!(code, "            add_inplace_batch_pipeline,")?;
3923    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3924    writeln!(code, "            attention_batch_pipeline,")?;
3925    writeln!(code, "            attention_flash_batch_pipeline,")?;
3926    writeln!(code, "            attention_mma_flash_batch_pipeline,")?;
3927    writeln!(code, "            copy_kv_batch_pipeline,")?;
3928    writeln!(code, "            rope_qk_batch_pipeline,")?;
3929    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3930    writeln!(code, "            embed_buf,")?;
3931    writeln!(code, "            layers,")?;
3932    writeln!(code, "            norm_buf,")?;
3933    writeln!(code, "            lm_head_buf,")?;
3934    writeln!(code, "            hidden_buf,")?;
3935    writeln!(code, "            residual_buf,")?;
3936    writeln!(code, "            normed_buf,")?;
3937    writeln!(code, "            qkv_buf,")?;
3938    writeln!(code, "            attn_out_buf,")?;
3939    writeln!(code, "            attn_proj_buf,")?;
3940    writeln!(code, "            gate_up_buf,")?;
3941    writeln!(code, "            ffn_hidden_buf,")?;
3942    writeln!(code, "            ffn_out_buf,")?;
3943    writeln!(code, "            add_tmp_buf,")?;
3944    writeln!(code, "            logits_buf,")?;
3945    writeln!(code, "            batch_hidden_buf,")?;
3946    writeln!(code, "            batch_residual_buf,")?;
3947    writeln!(code, "            batch_qkv_buf,")?;
3948    writeln!(code, "            batch_attn_out_buf,")?;
3949    writeln!(code, "            batch_attn_proj_buf,")?;
3950    writeln!(code, "            batch_gate_up_buf,")?;
3951    writeln!(code, "            batch_ffn_hidden_buf,")?;
3952    writeln!(code, "            batch_ffn_out_buf,")?;
3953    writeln!(code, "            batch_tokens_buf,")?;
3954    writeln!(code, "            batch_positions_buf,")?;
3955    writeln!(code, "            k_cache,")?;
3956    writeln!(code, "            v_cache,")?;
3957    writeln!(code, "            pos: 0,")?;
3958    writeln!(code, "            prev_cmd: None,")?;
3959    writeln!(code, "        }}")?;
3960    writeln!(code, "    }}")?;
3961    writeln!(code)?;
3962
3963    // ── forward() ──
3964    writeln!(
3965        code,
3966        "    /// Run the forward pass for a single token at the current position."
3967    )?;
3968    writeln!(code, "    ///")?;
3969    writeln!(
3970        code,
3971        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3972    )?;
3973    writeln!(code, "    ///")?;
3974    writeln!(
3975        code,
3976        "    /// All GPU operations are encoded into a single command buffer and"
3977    )?;
3978    writeln!(
3979        code,
3980        "    /// committed once at the end, avoiding per-operation synchronization."
3981    )?;
3982    writeln!(
3983        code,
3984        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3985    )?;
3986    writeln!(
3987        code,
3988        "        // Wait for any pending prefill command buffer"
3989    )?;
3990    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3991    writeln!(code, "            prev.wait_until_completed();")?;
3992    writeln!(code, "        }}")?;
3993    writeln!(code)?;
3994    writeln!(code, "        let pos = self.pos;")?;
3995    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3996    writeln!(code)?;
3997
3998    // Single compute encoder for the entire forward pass — no blit encoder
3999    // transitions. Copy operations use compute copy kernels instead of blits.
4000    let matmul_fn = if is_q8 {
4001        "dispatch_matmul_q8"
4002    } else if is_q4 {
4003        "dispatch_matmul_q4"
4004    } else {
4005        "dispatch_matmul"
4006    };
4007
4008    writeln!(
4009        code,
4010        "        // Single compute encoder for the entire forward pass (no blit transitions)"
4011    )?;
4012    writeln!(code, "        {{")?;
4013    writeln!(
4014        code,
4015        "            let enc = cmd.new_compute_command_encoder();"
4016    )?;
4017    writeln!(code)?;
4018
4019    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
4020    writeln!(
4021        code,
4022        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
4023    )?;
4024    writeln!(
4025        code,
4026        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
4027    )?;
4028    writeln!(
4029        code,
4030        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
4031    )?;
4032    writeln!(
4033        code,
4034        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4035        hidden * 4,
4036    )?;
4037    writeln!(code, "            unsafe {{")?;
4038    writeln!(
4039        code,
4040        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
4041    )?;
4042    writeln!(
4043        code,
4044        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4045    )?;
4046    writeln!(
4047        code,
4048        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
4049    )?;
4050    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
4051    writeln!(
4052        code,
4053        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4054    )?;
4055    writeln!(code, "                    hidden_ptr,")?;
4056    writeln!(code, "                    HIDDEN_SIZE,")?;
4057    writeln!(code, "                );")?;
4058    writeln!(
4059        code,
4060        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4061    )?;
4062    writeln!(code, "            }}")?;
4063    writeln!(code)?;
4064
4065    // 2. Transformer layers
4066    writeln!(code, "            // 2. Transformer layers")?;
4067    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4068    writeln!(code)?;
4069    let q_byte_offset = 0usize;
4070    let k_float_offset = hidden;
4071    let v_float_offset = hidden + kv_dim;
4072
4073    writeln!(
4074        code,
4075        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
4076    )?;
4077    writeln!(
4078        code,
4079        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4080    )?;
4081    writeln!(
4082        code,
4083        "                // Fused Q+K+V matmul: single dispatch for all three projections"
4084    )?;
4085    writeln!(
4086        code,
4087        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4088    )?;
4089    if config.qkv_bias {
4090        writeln!(
4091            code,
4092            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4093        )?;
4094        writeln!(
4095            code,
4096            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4097        )?;
4098    }
4099    writeln!(
4100        code,
4101        "                // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
4102    )?;
4103    writeln!(
4104        code,
4105        "                self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4106    )?;
4107    writeln!(code)?;
4108    writeln!(
4109        code,
4110        "                // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
4111    )?;
4112    writeln!(
4113        code,
4114        "                self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4115    )?;
4116    writeln!(code)?;
4117    writeln!(
4118        code,
4119        "                // Attention using Q from qkv_buf (offset 0)"
4120    )?;
4121    writeln!(
4122        code,
4123        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4124    )?;
4125    writeln!(
4126        code,
4127        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4128    )?;
4129    writeln!(
4130        code,
4131        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4132    )?;
4133    writeln!(
4134        code,
4135        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4136    )?;
4137    writeln!(
4138        code,
4139        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4140    )?;
4141    writeln!(
4142        code,
4143        "                // Fused gate+up matmul: single dispatch for both projections"
4144    )?;
4145    writeln!(
4146        code,
4147        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4148    )?;
4149    writeln!(
4150        code,
4151        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4152    )?;
4153    writeln!(
4154        code,
4155        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4156    )?;
4157    writeln!(
4158        code,
4159        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4160    )?;
4161    writeln!(code, "            }}")?;
4162    writeln!(code)?;
4163
4164    // 3. Final RMS norm + logits
4165    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
4166    writeln!(
4167        code,
4168        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4169    )?;
4170    writeln!(
4171        code,
4172        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4173    )?;
4174    writeln!(code)?;
4175    writeln!(code, "            enc.end_encoding();")?;
4176    writeln!(code, "        }}")?;
4177    writeln!(code)?;
4178
4179    // 5. Single commit + wait, then read back logits
4180    writeln!(
4181        code,
4182        "        // 5. Commit all GPU work and wait for completion"
4183    )?;
4184    writeln!(code, "        cmd.commit();")?;
4185    writeln!(code, "        cmd.wait_until_completed();")?;
4186    writeln!(code)?;
4187    writeln!(code, "        // 6. Read back logits from GPU")?;
4188    writeln!(code, "        let logits = unsafe {{")?;
4189    writeln!(
4190        code,
4191        "            let ptr = self.logits_buf.contents() as *const f32;"
4192    )?;
4193    writeln!(
4194        code,
4195        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4196    )?;
4197    writeln!(code, "        }};")?;
4198    writeln!(code)?;
4199    writeln!(code, "        self.pos += 1;")?;
4200    writeln!(code, "        logits")?;
4201    writeln!(code, "    }}")?;
4202    writeln!(code)?;
4203
4204    // ── forward_profile: instrumented forward with per-stage timing (embed / layers / norm+logits) ──
4205    writeln!(
4206        code,
4207        "    /// Profiling forward pass that prints per-stage GPU timing."
4208    )?;
4209    writeln!(code, "    ///")?;
4210    writeln!(
4211        code,
4212        "    /// Each stage is committed and waited on separately so that GPU timestamps"
4213    )?;
4214    writeln!(
4215        code,
4216        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
4217    )?;
4218    writeln!(
4219        code,
4220        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
4221    )?;
4222    writeln!(
4223        code,
4224        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4225    )?;
4226    writeln!(code, "        use std::time::Instant;")?;
4227    writeln!(code)?;
4228    writeln!(
4229        code,
4230        "        // Wait for any pending prefill command buffer"
4231    )?;
4232    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4233    writeln!(code, "            prev.wait_until_completed();")?;
4234    writeln!(code, "        }}")?;
4235    writeln!(code)?;
4236    writeln!(code, "        let pos = self.pos;")?;
4237    writeln!(code)?;
4238
4239    // Stage: embedding (CPU, no GPU)
4240    writeln!(
4241        code,
4242        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
4243    )?;
4244    writeln!(code, "        let t_embed = Instant::now();")?;
4245    writeln!(code, "        unsafe {{")?;
4246    writeln!(
4247        code,
4248        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
4249    )?;
4250    writeln!(
4251        code,
4252        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4253    )?;
4254    writeln!(
4255        code,
4256        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
4257    )?;
4258    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
4259    writeln!(
4260        code,
4261        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4262    )?;
4263    writeln!(code, "                hidden_ptr,")?;
4264    writeln!(code, "                HIDDEN_SIZE,")?;
4265    writeln!(code, "            );")?;
4266    writeln!(
4267        code,
4268        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4269    )?;
4270    writeln!(code, "        }}")?;
4271    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
4272    writeln!(code)?;
4273
4274    // Stage: Transformer layers (all together on GPU)
4275    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
4276    writeln!(code, "        let t_layers = Instant::now();")?;
4277    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4278    writeln!(code, "        {{")?;
4279    writeln!(
4280        code,
4281        "            let enc = cmd.new_compute_command_encoder();"
4282    )?;
4283    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4284    writeln!(
4285        code,
4286        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4287    )?;
4288    writeln!(
4289        code,
4290        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4291    )?;
4292    if config.qkv_bias {
4293        writeln!(
4294            code,
4295            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4296        )?;
4297    }
4298    writeln!(
4299        code,
4300        "                self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4301    )?;
4302    writeln!(
4303        code,
4304        "                self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4305    )?;
4306    writeln!(
4307        code,
4308        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4309    )?;
4310    writeln!(
4311        code,
4312        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4313    )?;
4314    writeln!(
4315        code,
4316        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4317    )?;
4318    writeln!(
4319        code,
4320        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4321    )?;
4322    writeln!(
4323        code,
4324        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4325    )?;
4326    writeln!(
4327        code,
4328        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4329    )?;
4330    writeln!(
4331        code,
4332        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4333    )?;
4334    writeln!(
4335        code,
4336        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4337    )?;
4338    writeln!(code, "            }}")?;
4339    writeln!(code, "            enc.end_encoding();")?;
4340    writeln!(code, "        }}")?;
4341    writeln!(code, "        cmd.commit();")?;
4342    writeln!(code, "        cmd.wait_until_completed();")?;
4343    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4344    writeln!(code)?;
4345
4346    // Stage: Final norm + logits
4347    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4348    writeln!(code, "        let t_logits = Instant::now();")?;
4349    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4350    writeln!(code, "        {{")?;
4351    writeln!(
4352        code,
4353        "            let enc = cmd2.new_compute_command_encoder();"
4354    )?;
4355    writeln!(
4356        code,
4357        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4358    )?;
4359    writeln!(
4360        code,
4361        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4362    )?;
4363    writeln!(code, "            enc.end_encoding();")?;
4364    writeln!(code, "        }}")?;
4365    writeln!(code, "        cmd2.commit();")?;
4366    writeln!(code, "        cmd2.wait_until_completed();")?;
4367    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4368    writeln!(code)?;
4369
4370    // Print profile results
4371    writeln!(
4372        code,
4373        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4374    )?;
4375    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4376    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4377    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4378    writeln!(
4379        code,
4380        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4381    )?;
4382    writeln!(code)?;
4383
4384    // Read back logits
4385    writeln!(code, "        let logits = unsafe {{")?;
4386    writeln!(
4387        code,
4388        "            let ptr = self.logits_buf.contents() as *const f32;"
4389    )?;
4390    writeln!(
4391        code,
4392        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4393    )?;
4394    writeln!(code, "        }};")?;
4395    writeln!(code)?;
4396    writeln!(code, "        self.pos += 1;")?;
4397    writeln!(code, "        logits")?;
4398    writeln!(code, "    }}")?;
4399    writeln!(code)?;
4400
4401    // ── forward_prefill: single-token async forward (backward compat) ──
4402    writeln!(
4403        code,
4404        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4405    )?;
4406    writeln!(code, "    ///")?;
4407    writeln!(
4408        code,
4409        "    /// Commits the command buffer without waiting, enabling double-buffered"
4410    )?;
4411    writeln!(
4412        code,
4413        "    /// execution: GPU processes token N while CPU encodes token N+1."
4414    )?;
4415    writeln!(
4416        code,
4417        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4418    )?;
4419    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4420    writeln!(code, "    }}")?;
4421    writeln!(code)?;
4422
4423    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4424    // Batched matmuls for QKV/O/FFN projections and batched causal attention (one dispatch for all M tokens).
4425    let batch_matmul_fn = if is_q8 {
4426        "dispatch_matmul_q8_batch"
4427    } else if is_q4 {
4428        "dispatch_matmul_q4_batch"
4429    } else {
4430        "dispatch_matmul_batch"
4431    };
4432
4433    writeln!(
4434        code,
4435        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4436    )?;
4437    writeln!(code, "    ///")?;
4438    writeln!(
4439        code,
4440        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4441    )?;
4442    writeln!(
4443        code,
4444        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4445    )?;
4446    writeln!(
4447        code,
4448        "    /// This provides significant speedup during prompt prefill."
4449    )?;
4450    writeln!(
4451        code,
4452        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4453    )?;
4454    writeln!(code, "        if tokens.is_empty() {{ return; }}")?;
4455    writeln!(
4456        code,
4457        "        // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4458    )?;
4459    writeln!(
4460        code,
4461        "        // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4462    )?;
4463    writeln!(
4464        code,
4465        "        // longer than that must be processed iteratively.  The KV cache"
4466    )?;
4467    writeln!(code, "        // carries state across chunks via self.pos.")?;
4468    writeln!(
4469        code,
4470        "        for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4471    )?;
4472    writeln!(code, "        let m = chunk.len();")?;
4473    writeln!(code, "        if m == 0 {{ continue; }}")?;
4474    writeln!(code, "        let start_pos = self.pos;")?;
4475    writeln!(code)?;
4476    writeln!(code, "        // Wait for any pending command buffer")?;
4477    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4478    writeln!(code, "            prev.wait_until_completed();")?;
4479    writeln!(code, "        }}")?;
4480    writeln!(code)?;
4481
4482    // Upload token IDs and positions to GPU
4483    writeln!(
4484        code,
4485        "        // Upload token IDs and positions to GPU buffers"
4486    )?;
4487    writeln!(code, "        unsafe {{")?;
4488    writeln!(
4489        code,
4490        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4491    )?;
4492    writeln!(
4493        code,
4494        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4495    )?;
4496    writeln!(code, "            for i in 0..m {{")?;
4497    writeln!(code, "                *tok_ptr.add(i) = chunk[i];")?;
4498    writeln!(
4499        code,
4500        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4501    )?;
4502    writeln!(code, "            }}")?;
4503    writeln!(code, "        }}")?;
4504    writeln!(code)?;
4505
4506    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4507    writeln!(code, "        {{")?;
4508    writeln!(
4509        code,
4510        "            let enc = cmd.new_compute_command_encoder();"
4511    )?;
4512    writeln!(code)?;
4513
4514    // 1. Batch embedding lookup
4515    writeln!(
4516        code,
4517        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4518    )?;
4519    writeln!(
4520        code,
4521        "            self.dispatch_copy_embedding_batch(&enc, m);"
4522    )?;
4523    // Copy batch_hidden -> batch_residual
4524    writeln!(
4525        code,
4526        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4527    )?;
4528    writeln!(code)?;
4529
4530    // 2. Transformer layers
4531    writeln!(code, "            // 2. Transformer layers")?;
4532    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4533    writeln!(code)?;
4534
4535    // Batch RMS norm: residual -> hidden (batched)
4536    writeln!(
4537        code,
4538        "                // Batch RMS norm: batch_residual -> batch_hidden"
4539    )?;
4540    writeln!(
4541        code,
4542        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4543    )?;
4544
4545    // Batch QKV matmul
4546    writeln!(
4547        code,
4548        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4549    )?;
4550    writeln!(
4551        code,
4552        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4553    )?;
4554    if config.qkv_bias {
4555        writeln!(
4556            code,
4557            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4558        )?;
4559        writeln!(
4560            code,
4561            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4562        )?;
4563    }
4564    writeln!(code)?;
4565
4566    // Fused RoPE on Q+K portions in a single dispatch
4567    let k_float_offset = hidden;
4568    writeln!(
4569        code,
4570        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4571    )?;
4572    writeln!(
4573        code,
4574        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4575    )?;
4576    writeln!(code)?;
4577
4578    // Fused KV cache update: copy both K and V in a single dispatch
4579    let v_float_offset = hidden + kv_dim;
4580    writeln!(
4581        code,
4582        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4583    )?;
4584    writeln!(
4585        code,
4586        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4587    )?;
4588    writeln!(code)?;
4589
4590    // Batched causal attention: ONE dispatch for all M tokens
4591    writeln!(
4592        code,
4593        "                // Batched causal attention: one dispatch for all M tokens"
4594    )?;
4595    writeln!(
4596        code,
4597        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4598    )?;
4599    writeln!(code)?;
4600
4601    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4602    writeln!(code, "                // Batched O projection")?;
4603    writeln!(
4604        code,
4605        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4606    )?;
4607    writeln!(code)?;
4608
4609    // Batch add: residual += attn_proj for all tokens
4610    writeln!(
4611        code,
4612        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4613    )?;
4614    writeln!(code)?;
4615
4616    // Batch FFN
4617    writeln!(
4618        code,
4619        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4620    )?;
4621    writeln!(
4622        code,
4623        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4624    )?;
4625    writeln!(
4626        code,
4627        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4628    )?;
4629    writeln!(
4630        code,
4631        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4632    )?;
4633    writeln!(
4634        code,
4635        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4636    )?;
4637    writeln!(
4638        code,
4639        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4640    )?;
4641    writeln!(code, "            }}")?;
4642    writeln!(code)?;
4643
4644    // Copy last token's residual to single-token residual_buf for next forward() call
4645    writeln!(
4646        code,
4647        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4648    )?;
4649    writeln!(
4650        code,
4651        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4652    )?;
4653    writeln!(code)?;
4654    writeln!(code, "            enc.end_encoding();")?;
4655    writeln!(code, "        }}")?;
4656    writeln!(code)?;
4657
4658    writeln!(code, "        cmd.commit();")?;
4659    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4660    writeln!(code, "        self.pos += m;")?;
4661    writeln!(code, "        }}  // end for chunk")?;
4662    writeln!(code, "    }}")?;
4663    writeln!(code)?;
4664
4665    // ── reset() — rewind KV cache position for new inference requests ──
4666    writeln!(
4667        code,
4668        "    /// Reset the model state for a new inference request."
4669    )?;
4670    writeln!(code, "    pub fn reset(&mut self) {{")?;
4671    writeln!(code, "        self.pos = 0;")?;
4672    writeln!(code, "        self.prev_cmd = None;")?;
4673    writeln!(code, "    }}")?;
4674    writeln!(code)?;
4675
4676    // ── Private dispatch helpers (all take a shared compute encoder) ──
4677    writeln!(
4678        code,
4679        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4680    )?;
4681    writeln!(
4682        code,
4683        "    // These methods set pipeline state + buffers + dispatch on an existing"
4684    )?;
4685    writeln!(
4686        code,
4687        "    // encoder, avoiding per-operation encoder creation overhead."
4688    )?;
4689    writeln!(code)?;
4690
4691    // dispatch_rms_norm
4692    writeln!(
4693        code,
4694        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4695    )?;
4696    writeln!(
4697        code,
4698        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4699    )?;
4700    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4701    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4702    writeln!(
4703        code,
4704        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4705    )?;
4706    writeln!(
4707        code,
4708        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4709    )?;
4710    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4711    writeln!(
4712        code,
4713        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4714    )?;
4715    writeln!(
4716        code,
4717        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4718    )?;
4719    writeln!(
4720        code,
4721        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4722    )?;
4723    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4724    writeln!(
4725        code,
4726        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4727    )?;
4728    writeln!(
4729        code,
4730        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4731    )?;
4732    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4733    writeln!(code, "    }}")?;
4734    writeln!(code)?;
4735
4736    // dispatch_matmul
4737    writeln!(
4738        code,
4739        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4740    )?;
4741    writeln!(
4742        code,
4743        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4744    )?;
4745    writeln!(code, "        let r: u32 = rows as u32;")?;
4746    writeln!(code, "        let c: u32 = cols as u32;")?;
4747    writeln!(
4748        code,
4749        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4750    )?;
4751    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4752    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4753    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4754    writeln!(
4755        code,
4756        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4757    )?;
4758    writeln!(
4759        code,
4760        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4761    )?;
4762    writeln!(
4763        code,
4764        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4765    )?;
4766    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4767    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4768    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4769    writeln!(
4770        code,
4771        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4772    )?;
4773    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4774    writeln!(code, "    }}")?;
4775    writeln!(code)?;
4776
4777    // dispatch_matmul_q8
4778    writeln!(
4779        code,
4780        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4781    )?;
4782    writeln!(
4783        code,
4784        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4785    )?;
4786    writeln!(
4787        code,
4788        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4789    )?;
4790    writeln!(code, "        let r: u32 = rows as u32;")?;
4791    writeln!(code, "        let c: u32 = cols as u32;")?;
4792    writeln!(
4793        code,
4794        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4795    )?;
4796    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4797    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4798    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4799    writeln!(
4800        code,
4801        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4802    )?;
4803    writeln!(
4804        code,
4805        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4806    )?;
4807    writeln!(
4808        code,
4809        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4810    )?;
4811    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4812    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4813    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4814    writeln!(
4815        code,
4816        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4817    )?;
4818    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4819    writeln!(code, "    }}")?;
4820    writeln!(code)?;
4821
4822    // dispatch_matmul_q4
4823    writeln!(
4824        code,
4825        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4826    )?;
4827    writeln!(
4828        code,
4829        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4830    )?;
4831    writeln!(
4832        code,
4833        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4834    )?;
4835    writeln!(code, "        let r: u32 = rows as u32;")?;
4836    writeln!(code, "        let c: u32 = cols as u32;")?;
4837    writeln!(
4838        code,
4839        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4840    )?;
4841    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4842    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4843    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4844    writeln!(
4845        code,
4846        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4847    )?;
4848    writeln!(
4849        code,
4850        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4851    )?;
4852    writeln!(
4853        code,
4854        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4855    )?;
4856    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4857    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4858    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4859    writeln!(
4860        code,
4861        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4862    )?;
4863    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4864    writeln!(code, "    }}")?;
4865    writeln!(code)?;
4866
4867    // dispatch_rope
4868    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4869    writeln!(
4870        code,
4871        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4872    )?;
4873    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4874    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4875    writeln!(code, "        let p: u32 = pos as u32;")?;
4876    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4877    writeln!(
4878        code,
4879        "        let total_pairs = num_heads * (head_dim / 2);"
4880    )?;
4881    writeln!(
4882        code,
4883        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4884    )?;
4885    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4886    writeln!(
4887        code,
4888        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4889    )?;
4890    writeln!(
4891        code,
4892        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4893    )?;
4894    writeln!(
4895        code,
4896        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4897    )?;
4898    writeln!(
4899        code,
4900        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4901    )?;
4902    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4903    writeln!(
4904        code,
4905        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4906    )?;
4907    writeln!(
4908        code,
4909        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4910    )?;
4911    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4912    writeln!(code, "    }}")?;
4913    writeln!(code)?;
4914
4915    // dispatch_rope_offset
4916    writeln!(
4917        code,
4918        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4919    )?;
4920    writeln!(
4921        code,
4922        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4923    )?;
4924    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4925    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4926    writeln!(code, "        let p: u32 = pos as u32;")?;
4927    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4928    writeln!(
4929        code,
4930        "        let total_pairs = num_heads * (head_dim / 2);"
4931    )?;
4932    writeln!(
4933        code,
4934        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4935    )?;
4936    writeln!(
4937        code,
4938        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4939    )?;
4940    writeln!(
4941        code,
4942        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4943    )?;
4944    writeln!(
4945        code,
4946        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4947    )?;
4948    writeln!(
4949        code,
4950        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4951    )?;
4952    writeln!(
4953        code,
4954        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4955    )?;
4956    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4957    writeln!(
4958        code,
4959        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4960    )?;
4961    writeln!(
4962        code,
4963        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4964    )?;
4965    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4966    writeln!(code, "    }}")?;
4967    writeln!(code)?;
4968
4969    // dispatch_attention
4970    writeln!(
4971        code,
4972        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4973    )?;
4974    writeln!(
4975        code,
4976        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4977    )?;
4978    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4979    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4980    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4981    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4982    writeln!(
4983        code,
4984        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4985    )?;
4986    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4987    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4988    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4989    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4990    writeln!(
4991        code,
4992        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4993    )?;
4994    writeln!(
4995        code,
4996        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4997    )?;
4998    writeln!(
4999        code,
5000        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5001    )?;
5002    writeln!(
5003        code,
5004        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5005    )?;
5006    writeln!(
5007        code,
5008        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
5009    )?;
5010    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5011    writeln!(
5012        code,
5013        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5014    )?;
5015    writeln!(
5016        code,
5017        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5018    )?;
5019    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5020    writeln!(code, "    }}")?;
5021    writeln!(code)?;
5022
5023    // dispatch_attention_offset
5024    writeln!(
5025        code,
5026        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
5027    )?;
5028    writeln!(
5029        code,
5030        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5031    )?;
5032    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
5033    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5034    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5035    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5036    writeln!(
5037        code,
5038        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
5039    )?;
5040    writeln!(
5041        code,
5042        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5043    )?;
5044    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5045    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5046    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5047    writeln!(
5048        code,
5049        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5050    )?;
5051    writeln!(
5052        code,
5053        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5054    )?;
5055    writeln!(
5056        code,
5057        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5058    )?;
5059    writeln!(
5060        code,
5061        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5062    )?;
5063    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5064    writeln!(
5065        code,
5066        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5067    )?;
5068    writeln!(
5069        code,
5070        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5071    )?;
5072    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5073    writeln!(code, "    }}")?;
5074    writeln!(code)?;
5075
5076    // dispatch_silu_mul
5077    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
5078    writeln!(
5079        code,
5080        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5081    )?;
5082    writeln!(code, "        let count: u32 = n as u32;")?;
5083    writeln!(
5084        code,
5085        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5086    )?;
5087    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
5088    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
5089    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5090    writeln!(
5091        code,
5092        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5093    )?;
5094    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5095    writeln!(
5096        code,
5097        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5098    )?;
5099    writeln!(
5100        code,
5101        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5102    )?;
5103    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5104    writeln!(code, "    }}")?;
5105    writeln!(code)?;
5106
5107    // dispatch_silu_mul_fused
5108    writeln!(
5109        code,
5110        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5111    )?;
5112    writeln!(
5113        code,
5114        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
5115    )?;
5116    writeln!(
5117        code,
5118        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5119    )?;
5120    writeln!(code, "        let count: u32 = n as u32;")?;
5121    writeln!(
5122        code,
5123        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5124    )?;
5125    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5126    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5127    writeln!(
5128        code,
5129        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5130    )?;
5131    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5132    writeln!(
5133        code,
5134        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5135    )?;
5136    writeln!(
5137        code,
5138        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5139    )?;
5140    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5141    writeln!(code, "    }}")?;
5142    writeln!(code)?;
5143
5144    // dispatch_copy (simple src -> dst copy via compute kernel)
5145    writeln!(
5146        code,
5147        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5148    )?;
5149    writeln!(
5150        code,
5151        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5152    )?;
5153    writeln!(code, "        let n: u32 = count as u32;")?;
5154    writeln!(
5155        code,
5156        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5157    )?;
5158    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5159    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5160    writeln!(
5161        code,
5162        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5163    )?;
5164    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5165    writeln!(
5166        code,
5167        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5168    )?;
5169    writeln!(
5170        code,
5171        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5172    )?;
5173    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5174    writeln!(code, "    }}")?;
5175    writeln!(code)?;
5176
5177    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
5178    writeln!(
5179        code,
5180        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5181    )?;
5182    writeln!(
5183        code,
5184        "    /// Used for embedding table lookup (copy a specific row)."
5185    )?;
5186    writeln!(
5187        code,
5188        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5189    )?;
5190    writeln!(code, "        let off: u32 = src_offset as u32;")?;
5191    writeln!(code, "        let n: u32 = count as u32;")?;
5192    writeln!(
5193        code,
5194        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5195    )?;
5196    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5197    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5198    writeln!(
5199        code,
5200        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5201    )?;
5202    writeln!(
5203        code,
5204        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5205    )?;
5206    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5207    writeln!(
5208        code,
5209        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5210    )?;
5211    writeln!(
5212        code,
5213        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5214    )?;
5215    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5216    writeln!(code, "    }}")?;
5217    writeln!(code)?;
5218
5219    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
5220    writeln!(
5221        code,
5222        "    /// Dispatch copy from source at byte offset to destination at float offset."
5223    )?;
5224    writeln!(
5225        code,
5226        "    /// Used for KV cache updates from fused QKV buffer."
5227    )?;
5228    writeln!(
5229        code,
5230        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5231    )?;
5232    writeln!(code, "        let n: u32 = count as u32;")?;
5233    writeln!(
5234        code,
5235        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5236    )?;
5237    writeln!(
5238        code,
5239        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5240    )?;
5241    writeln!(
5242        code,
5243        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5244    )?;
5245    writeln!(
5246        code,
5247        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5248    )?;
5249    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5250    writeln!(
5251        code,
5252        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5253    )?;
5254    writeln!(
5255        code,
5256        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5257    )?;
5258    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5259    writeln!(code, "    }}")?;
5260    writeln!(code)?;
5261
5262    // dispatch_copy_from_offset_f16 (copy src[src_byte_offset..] -> f16 dst[dst_elem_offset..])
5263    writeln!(
5264        code,
5265        "    /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5266    )?;
5267    writeln!(
5268        code,
5269        "    /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5270    )?;
5271    writeln!(
5272        code,
5273        "    fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5274    )?;
5275    writeln!(code, "        let n: u32 = count as u32;")?;
5276    writeln!(
5277        code,
5278        "        enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5279    )?;
5280    writeln!(
5281        code,
5282        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5283    )?;
5284    writeln!(
5285        code,
5286        "        enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5287    )?;
5288    writeln!(
5289        code,
5290        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5291    )?;
5292    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5293    writeln!(
5294        code,
5295        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5296    )?;
5297    writeln!(
5298        code,
5299        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5300    )?;
5301    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5302    writeln!(code, "    }}")?;
5303    writeln!(code)?;
5304
5305    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
5306    writeln!(
5307        code,
5308        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5309    )?;
5310    writeln!(
5311        code,
5312        "    /// Used for KV cache updates (write to a specific position in the cache)."
5313    )?;
5314    writeln!(
5315        code,
5316        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5317    )?;
5318    writeln!(code, "        let n: u32 = count as u32;")?;
5319    writeln!(
5320        code,
5321        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5322    )?;
5323    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5324    writeln!(
5325        code,
5326        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5327    )?;
5328    writeln!(
5329        code,
5330        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5331    )?;
5332    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5333    writeln!(
5334        code,
5335        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5336    )?;
5337    writeln!(
5338        code,
5339        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5340    )?;
5341    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5342    writeln!(code, "    }}")?;
5343    writeln!(code)?;
5344
5345    // dispatch_add_inplace (residual connection, no blit needed)
5346    writeln!(
5347        code,
5348        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5349    )?;
5350    writeln!(
5351        code,
5352        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5353    )?;
5354    writeln!(code, "        let count: u32 = n as u32;")?;
5355    writeln!(
5356        code,
5357        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5358    )?;
5359    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5360    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5361    writeln!(
5362        code,
5363        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5364    )?;
5365    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5366    writeln!(
5367        code,
5368        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5369    )?;
5370    writeln!(
5371        code,
5372        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5373    )?;
5374    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5375    writeln!(code, "    }}")?;
5376    writeln!(code)?;
5377
5378    // ── Batched prefill dispatch helpers ──
5379    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
5380    writeln!(code)?;
5381
5382    // dispatch_copy_embedding_batch
5383    writeln!(
5384        code,
5385        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
5386    )?;
5387    writeln!(
5388        code,
5389        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5390    )?;
5391    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
5392    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5393    writeln!(
5394        code,
5395        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5396    )?;
5397    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5398    writeln!(
5399        code,
5400        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5401    )?;
5402    writeln!(
5403        code,
5404        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5405    )?;
5406    writeln!(
5407        code,
5408        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5409    )?;
5410    writeln!(
5411        code,
5412        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5413    )?;
5414    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
5415    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5416    writeln!(
5417        code,
5418        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5419    )?;
5420    writeln!(
5421        code,
5422        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5423    )?;
5424    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5425    writeln!(code, "    }}")?;
5426    writeln!(code)?;
5427
5428    // dispatch_rms_norm_batch
5429    writeln!(
5430        code,
5431        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
5432    )?;
5433    writeln!(
5434        code,
5435        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5436    )?;
5437    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
5438    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
5439    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5440    writeln!(
5441        code,
5442        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5443    )?;
5444    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
5445    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
5446    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5447    writeln!(
5448        code,
5449        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5450    )?;
5451    writeln!(
5452        code,
5453        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5454    )?;
5455    writeln!(
5456        code,
5457        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5458    )?;
5459    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5460    writeln!(
5461        code,
5462        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5463    )?;
5464    writeln!(
5465        code,
5466        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5467    )?;
5468    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5469    writeln!(code, "    }}")?;
5470    writeln!(code)?;
5471
5472    // dispatch_matmul_batch (f32)
5473    writeln!(
5474        code,
5475        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5476    )?;
5477    writeln!(
5478        code,
5479        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5480    )?;
5481    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5482    writeln!(code, "        let r: u32 = rows as u32;")?;
5483    writeln!(code, "        let c: u32 = cols as u32;")?;
5484    writeln!(
5485        code,
5486        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5487    )?;
5488    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5489    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5490    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5491    writeln!(
5492        code,
5493        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5494    )?;
5495    writeln!(
5496        code,
5497        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5498    )?;
5499    writeln!(
5500        code,
5501        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5502    )?;
5503    writeln!(
5504        code,
5505        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
5506    )?;
5507    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5508    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5509    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5510    writeln!(
5511        code,
5512        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5513    )?;
5514    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5515    writeln!(code, "    }}")?;
5516    writeln!(code)?;
5517
5518    // dispatch_matmul_q8_batch
5519    writeln!(
5520        code,
5521        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5522    )?;
5523    writeln!(code, "    ///")?;
5524    writeln!(
5525        code,
5526        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5527    )?;
5528    writeln!(
5529        code,
5530        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5531    )?;
5532    writeln!(
5533        code,
5534        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5535    )?;
5536    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5537    writeln!(code, "        let r: u32 = rows as u32;")?;
5538    writeln!(code, "        let c: u32 = cols as u32;")?;
5539    writeln!(
5540        code,
5541        "        // Tile sizes must match the Metal shader constants."
5542    )?;
5543    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5544    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5545    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5546    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5547    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5548    writeln!(
5549        code,
5550        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5551    )?;
5552    writeln!(
5553        code,
5554        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5555    )?;
5556    writeln!(
5557        code,
5558        "        // dispatch count and reuses each weight load across 32 tokens."
5559    )?;
5560    writeln!(
5561        code,
5562        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5563    )?;
5564    writeln!(
5565        code,
5566        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5567    )?;
5568    writeln!(
5569        code,
5570        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5571    )?;
5572    writeln!(
5573        code,
5574        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5575    )?;
5576    writeln!(
5577        code,
5578        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5579    )?;
5580    writeln!(
5581        code,
5582        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5583    )?;
5584    writeln!(
5585        code,
5586        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5587    )?;
5588    writeln!(
5589        code,
5590        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5591    )?;
5592    writeln!(
5593        code,
5594        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5595    )?;
5596    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5597    writeln!(code, "            let pipe = if use_h4 {{")?;
5598    writeln!(code, "                if num_tokens >= 256 {{")?;
5599    writeln!(
5600        code,
5601        "                    &self.matmul_q8_mma32_hh4_pipeline"
5602    )?;
5603    writeln!(code, "                }} else {{")?;
5604    writeln!(
5605        code,
5606        "                    &self.matmul_q8_mma32_h4_pipeline"
5607    )?;
5608    writeln!(code, "                }}")?;
5609    writeln!(code, "            }} else {{")?;
5610    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5611    writeln!(code, "            }};")?;
5612    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
5613    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5614    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5615    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5616    writeln!(
5617        code,
5618        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5619    )?;
5620    writeln!(
5621        code,
5622        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5623    )?;
5624    writeln!(
5625        code,
5626        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5627    )?;
5628    writeln!(code, "            let row_tgs = rows / MMA32_ROW_TILE;")?;
5629    writeln!(
5630        code,
5631        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5632    )?;
5633    writeln!(
5634        code,
5635        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5636    )?;
5637    writeln!(
5638        code,
5639        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5640    )?;
5641    writeln!(
5642        code,
5643        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5644    )?;
5645    writeln!(
5646        code,
5647        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5648    )?;
5649    writeln!(
5650        code,
5651        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5652    )?;
5653    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5654    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5655    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5656    writeln!(
5657        code,
5658        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5659    )?;
5660    writeln!(
5661        code,
5662        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5663    )?;
5664    writeln!(
5665        code,
5666        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5667    )?;
5668    writeln!(code, "            let row_tgs = rows / MMA_ROW_TILE;")?;
5669    writeln!(
5670        code,
5671        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5672    )?;
5673    writeln!(
5674        code,
5675        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5676    )?;
5677    writeln!(
5678        code,
5679        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5680    )?;
5681    writeln!(
5682        code,
5683        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5684    )?;
5685    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5686    writeln!(
5687        code,
5688        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5689    )?;
5690    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5691    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5692    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5693    writeln!(
5694        code,
5695        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5696    )?;
5697    writeln!(
5698        code,
5699        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5700    )?;
5701    writeln!(
5702        code,
5703        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5704    )?;
5705    writeln!(
5706        code,
5707        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5708    )?;
5709    writeln!(
5710        code,
5711        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5712    )?;
5713    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5714    writeln!(
5715        code,
5716        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5717    )?;
5718    writeln!(
5719        code,
5720        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5721    )?;
5722    writeln!(code, "        }} else {{")?;
5723    writeln!(
5724        code,
5725        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5726    )?;
5727    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5728    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5729    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5730    writeln!(
5731        code,
5732        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5733    )?;
5734    writeln!(
5735        code,
5736        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5737    )?;
5738    writeln!(
5739        code,
5740        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5741    )?;
5742    writeln!(
5743        code,
5744        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5745    )?;
5746    writeln!(
5747        code,
5748        "            let num_tg = (row_tgs * num_tokens) as u64;"
5749    )?;
5750    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5751    writeln!(
5752        code,
5753        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5754    )?;
5755    writeln!(
5756        code,
5757        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5758    )?;
5759    writeln!(code, "        }}")?;
5760    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5761    writeln!(code, "    }}")?;
5762    writeln!(code)?;
5763
5764    // dispatch_matmul_q4_batch
5765    writeln!(
5766        code,
5767        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5768    )?;
5769    writeln!(
5770        code,
5771        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5772    )?;
5773    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5774    writeln!(code, "        let r: u32 = rows as u32;")?;
5775    writeln!(code, "        let c: u32 = cols as u32;")?;
5776    writeln!(
5777        code,
5778        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5779    )?;
5780    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5781    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5782    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5783    writeln!(
5784        code,
5785        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5786    )?;
5787    writeln!(
5788        code,
5789        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5790    )?;
5791    writeln!(
5792        code,
5793        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5794    )?;
5795    writeln!(
5796        code,
5797        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5798    )?;
5799    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5800    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5801    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5802    writeln!(
5803        code,
5804        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5805    )?;
5806    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5807    writeln!(code, "    }}")?;
5808    writeln!(code)?;
5809
5810    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5811    if config.qkv_bias {
5812        writeln!(
5813            code,
5814            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5815        )?;
5816        writeln!(
5817            code,
5818            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5819        )?;
5820        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5821        writeln!(code, "        let r: u32 = rows as u32;")?;
5822        writeln!(
5823            code,
5824            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5825        )?;
5826        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5827        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5828        writeln!(
5829            code,
5830            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5831        )?;
5832        writeln!(
5833            code,
5834            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5835        )?;
5836        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5837        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5838        writeln!(
5839            code,
5840            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5841        )?;
5842        writeln!(
5843            code,
5844            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5845        )?;
5846        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5847        writeln!(code, "    }}")?;
5848        writeln!(code)?;
5849    }
5850
5851    // dispatch_rope_batch
5852    writeln!(
5853        code,
5854        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5855    )?;
5856    writeln!(
5857        code,
5858        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5859    )?;
5860    writeln!(
5861        code,
5862        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5863    )?;
5864    writeln!(
5865        code,
5866        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5867    )?;
5868    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5869    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5870    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5871    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5872    writeln!(
5873        code,
5874        "        let pairs_per_token = num_heads * (head_dim / 2);"
5875    )?;
5876    writeln!(
5877        code,
5878        "        let total_pairs = num_tokens * pairs_per_token;"
5879    )?;
5880    // The `rope_batch` kernel expects contiguous [M, num_heads * head_dim] data,
5881    // i.e. it indexes data[token * (num_heads * head_dim) + ...].
5882    // Our batch_qkv_buf is laid out as [M, qkv_rows], with Q and K stored at
5883    // offsets inside each token's row, so the element we actually need is
5884    // data[token * row_stride + data_float_offset + ...].
5885    // Binding the buffer at byte offset (data_float_offset * 4) is not enough on
5886    // its own: the kernel would also have to know row_stride, but `rope_batch`
5887    // as written assumes the [M, num_heads * head_dim] layout.
5888    //
5889    // Simplest approach: run the single-token rope kernel once per token in a
5890    // loop, binding the buffer at that token's byte offset. This is still
5891    // efficient because every dispatch is encoded in the same command encoder.
5892    writeln!(
5893        code,
5894        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5895    )?;
5896    writeln!(code, "        for t in 0..num_tokens {{")?;
5897    writeln!(
5898        code,
5899        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5900    )?;
5901    writeln!(
5902        code,
5903        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5904    )?;
5905    writeln!(
5906        code,
5907        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5908    )?;
5909    writeln!(
5910        code,
5911        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5912    )?;
5913    writeln!(
5914        code,
5915        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5916    )?;
5917    writeln!(
5918        code,
5919        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5920    )?;
5921    writeln!(
5922        code,
5923        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5924    )?;
5925    writeln!(
5926        code,
5927        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5928    )?;
5929    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5930    writeln!(
5931        code,
5932        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5933    )?;
5934    writeln!(
5935        code,
5936        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5937    )?;
5938    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5939    writeln!(code, "        }}")?;
5940    writeln!(code, "    }}")?;
5941    writeln!(code)?;
5942
5943    // dispatch_silu_mul_fused_batch
5944    writeln!(
5945        code,
5946        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5947    )?;
5948    writeln!(
5949        code,
5950        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5951    )?;
5952    writeln!(code, "        let count: u32 = n as u32;")?;
5953    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5954    writeln!(
5955        code,
5956        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5957    )?;
5958    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5959    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5960    writeln!(
5961        code,
5962        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5963    )?;
5964    writeln!(
5965        code,
5966        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5967    )?;
5968    writeln!(code, "        let total = num_tokens * n;")?;
5969    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5970    writeln!(
5971        code,
5972        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5973    )?;
5974    writeln!(
5975        code,
5976        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5977    )?;
5978    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5979    writeln!(code, "    }}")?;
5980    writeln!(code)?;
5981
5982    // dispatch_add_inplace_batch_n (add n elements in-place)
5983    writeln!(
5984        code,
5985        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5986    )?;
5987    writeln!(
5988        code,
5989        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5990    )?;
5991    writeln!(code, "        let count: u32 = total_n as u32;")?;
5992    writeln!(
5993        code,
5994        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5995    )?;
5996    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5997    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5998    writeln!(
5999        code,
6000        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
6001    )?;
6002    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6003    writeln!(
6004        code,
6005        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
6006    )?;
6007    writeln!(
6008        code,
6009        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6010    )?;
6011    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6012    writeln!(code, "    }}")?;
6013    writeln!(code)?;
6014
6015    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
6016    writeln!(
6017        code,
6018        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
6019    )?;
6020    writeln!(
6021        code,
6022        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
6023    )?;
6024    writeln!(code, "        let n: u32 = count as u32;")?;
6025    writeln!(
6026        code,
6027        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6028    )?;
6029    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6030    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6031    writeln!(
6032        code,
6033        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6034    )?;
6035    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6036    writeln!(
6037        code,
6038        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6039    )?;
6040    writeln!(
6041        code,
6042        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6043    )?;
6044    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6045    writeln!(code, "    }}")?;
6046    writeln!(code)?;
6047
6048    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
6049    writeln!(
6050        code,
6051        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6052    )?;
6053    writeln!(
6054        code,
6055        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6056    )?;
6057    writeln!(code, "        let n: u32 = count as u32;")?;
6058    writeln!(
6059        code,
6060        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6061    )?;
6062    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6063    writeln!(
6064        code,
6065        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6066    )?;
6067    writeln!(
6068        code,
6069        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6070    )?;
6071    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6072    writeln!(
6073        code,
6074        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6075    )?;
6076    writeln!(
6077        code,
6078        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6079    )?;
6080    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6081    writeln!(code, "    }}")?;
6082    writeln!(code)?;
6083
6084    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
6085    writeln!(
6086        code,
6087        "    /// Copy from src at byte offset to dst at float offset."
6088    )?;
6089    writeln!(
6090        code,
6091        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6092    )?;
6093    writeln!(code, "        let n: u32 = count as u32;")?;
6094    writeln!(
6095        code,
6096        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6097    )?;
6098    writeln!(
6099        code,
6100        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6101    )?;
6102    writeln!(
6103        code,
6104        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6105    )?;
6106    writeln!(
6107        code,
6108        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6109    )?;
6110    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6111    writeln!(
6112        code,
6113        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6114    )?;
6115    writeln!(
6116        code,
6117        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6118    )?;
6119    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6120    writeln!(code, "    }}")?;
6121    writeln!(code)?;
6122
6123    // dispatch_copy_kv_batch
6124    writeln!(
6125        code,
6126        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6127    )?;
6128    writeln!(
6129        code,
6130        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6131    )?;
6132    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6133    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6134    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6135    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6136    writeln!(code, "        let so: u32 = src_offset as u32;")?;
6137    writeln!(
6138        code,
6139        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6140    )?;
6141    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6142    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6143    writeln!(
6144        code,
6145        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6146    )?;
6147    writeln!(
6148        code,
6149        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6150    )?;
6151    writeln!(
6152        code,
6153        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6154    )?;
6155    writeln!(
6156        code,
6157        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6158    )?;
6159    writeln!(
6160        code,
6161        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6162    )?;
6163    writeln!(code, "        let total = num_tokens * kv_dim;")?;
6164    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6165    writeln!(
6166        code,
6167        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6168    )?;
6169    writeln!(
6170        code,
6171        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6172    )?;
6173    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6174    writeln!(code, "    }}")?;
6175    writeln!(code)?;
6176
6177    // dispatch_attention_batch
6178    writeln!(
6179        code,
6180        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
6181    )?;
6182    writeln!(
6183        code,
6184        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6185    )?;
6186    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6187    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6188    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
6189    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6190    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6191    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
6192    // Attention kernel selection:
6193    //   * Legacy `attention_batch` materializes scores[4096] in threadgroup memory
6194    //     and uses scalar simdgroup reductions.  Fast at short seq_len, no MMA.
6195    //   * `attention_flash_batch` streams K/V with online softmax; no seq cap,
6196    //     scalar math, ~7-14 % slower than legacy at long contexts (no MMA).
6197    //   * `attention_mma_flash_batch` adds hardware simdgroup_matrix<half, 8, 8>
6198    //     MMA for both Q·K^T and P·V, processing Q_BLOCK=8 tokens per TG.
6199    //     Default path when HEAD_DIM ≤ 128 and num_tokens ≥ 8 (verified on
6200    //     Llama, Qwen2.5, Phi-3).  Set FORGE_MMA_ATTN=0 to force legacy.
6201    writeln!(code, "        let max_seq = base_pos + num_tokens;")?;
6202    writeln!(code, "        let _ = max_seq;")?;
6203    writeln!(
6204        code,
6205        "        let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6206    )?;
6207    writeln!(code, "            .map(|v| v == \"0\").unwrap_or(false);")?;
6208    writeln!(
6209        code,
6210        "        let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6211    )?;
6212    writeln!(code, "        if use_mma_flash {{")?;
6213    writeln!(
6214        code,
6215        "            let pipe = &self.attention_mma_flash_batch_pipeline;"
6216    )?;
6217    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
6218    writeln!(code, "            enc.set_buffer(0, Some(q_buf), 0);")?;
6219    writeln!(code, "            enc.set_buffer(1, Some(k_cache), 0);")?;
6220    writeln!(code, "            enc.set_buffer(2, Some(v_cache), 0);")?;
6221    writeln!(code, "            enc.set_buffer(3, Some(output), 0);")?;
6222    writeln!(
6223        code,
6224        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6225    )?;
6226    writeln!(
6227        code,
6228        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6229    )?;
6230    writeln!(
6231        code,
6232        "            enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6233    )?;
6234    writeln!(
6235        code,
6236        "            enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6237    )?;
6238    writeln!(
6239        code,
6240        "            enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6241    )?;
6242    writeln!(
6243        code,
6244        "            enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6245    )?;
6246    writeln!(
6247        code,
6248        "            // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6249    )?;
6250    writeln!(code, "            let tg_size = MTLSize::new(128, 1, 1);")?;
6251    writeln!(
6252        code,
6253        "            let q_blocks = ((num_tokens + 7) / 8) as u64;"
6254    )?;
6255    writeln!(
6256        code,
6257        "            let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6258    )?;
6259    writeln!(
6260        code,
6261        "            enc.dispatch_thread_groups(grid_size, tg_size);"
6262    )?;
6263    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6264    writeln!(code, "            return;")?;
6265    writeln!(code, "        }}")?;
6266    writeln!(code, "        let pipe = &self.attention_batch_pipeline;")?;
6267    writeln!(code, "        enc.set_compute_pipeline_state(pipe);")?;
6268    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
6269    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
6270    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
6271    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
6272    writeln!(
6273        code,
6274        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6275    )?;
6276    writeln!(
6277        code,
6278        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6279    )?;
6280    writeln!(
6281        code,
6282        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6283    )?;
6284    writeln!(
6285        code,
6286        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6287    )?;
6288    writeln!(
6289        code,
6290        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6291    )?;
6292    writeln!(
6293        code,
6294        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6295    )?;
6296    writeln!(
6297        code,
6298        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6299    )?;
6300    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6301    writeln!(
6302        code,
6303        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6304    )?;
6305    writeln!(
6306        code,
6307        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6308    )?;
6309    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6310    writeln!(code, "    }}")?;
6311    writeln!(code)?;
6312
6313    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
6314    writeln!(
6315        code,
6316        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6317    )?;
6318    writeln!(
6319        code,
6320        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6321    )?;
6322    writeln!(
6323        code,
6324        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6325    )?;
6326    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6327    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6328    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
6329    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6330    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6331    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
6332    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
6333    writeln!(
6334        code,
6335        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6336    )?;
6337    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
6338    writeln!(
6339        code,
6340        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6341    )?;
6342    writeln!(
6343        code,
6344        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6345    )?;
6346    writeln!(
6347        code,
6348        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6349    )?;
6350    writeln!(
6351        code,
6352        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6353    )?;
6354    writeln!(
6355        code,
6356        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6357    )?;
6358    writeln!(
6359        code,
6360        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6361    )?;
6362    writeln!(
6363        code,
6364        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6365    )?;
6366    writeln!(
6367        code,
6368        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6369    )?;
6370    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6371    writeln!(
6372        code,
6373        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6374    )?;
6375    writeln!(
6376        code,
6377        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6378    )?;
6379    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6380    writeln!(code, "    }}")?;
6381    writeln!(code)?;
6382
6383    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
6384    writeln!(
6385        code,
6386        "    /// Dispatch fused K+V cache copy in one kernel launch."
6387    )?;
6388    writeln!(
6389        code,
6390        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6391    )?;
6392    writeln!(
6393        code,
6394        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6395    )?;
6396    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6397    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6398    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6399    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6400    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6401    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6402    writeln!(
6403        code,
6404        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6405    )?;
6406    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6407    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6408    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6409    writeln!(
6410        code,
6411        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6412    )?;
6413    writeln!(
6414        code,
6415        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6416    )?;
6417    writeln!(
6418        code,
6419        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6420    )?;
6421    writeln!(
6422        code,
6423        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6424    )?;
6425    writeln!(
6426        code,
6427        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6428    )?;
6429    writeln!(
6430        code,
6431        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6432    )?;
6433    writeln!(
6434        code,
6435        "        let total = num_tokens * kv_dim * 2;  // K + V"
6436    )?;
6437    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6438    writeln!(
6439        code,
6440        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6441    )?;
6442    writeln!(
6443        code,
6444        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6445    )?;
6446    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6447    writeln!(code, "    }}")?;
6448
6449    writeln!(code, "}}")?;
6450    writeln!(code)?;
6451
6452    Ok(())
6453}
6454
6455fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6456    writeln!(
6457        code,
6458        "// ── Helper functions ──────────────────────────────────"
6459    )?;
6460    writeln!(code)?;
6461    writeln!(
6462        code,
6463        "/// Create a compute pipeline from a named function in the Metal library."
6464    )?;
6465    writeln!(
6466        code,
6467        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6468    )?;
6469    writeln!(
6470        code,
6471        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6472    )?;
6473    writeln!(
6474        code,
6475        "    device.new_compute_pipeline_state_with_function(&func)"
6476    )?;
6477    writeln!(
6478        code,
6479        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6480    )?;
6481    writeln!(code, "}}")?;
6482    writeln!(code)?;
6483
6484    Ok(())
6485}
6486
6487// ---------------------------------------------------------------------------
6488// main.rs generation
6489// ---------------------------------------------------------------------------
6490
6491fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6492    let _sanitized = sanitize_name(model_name);
6493    let _vocab = config.vocab_size;
6494
6495    let mut code = String::with_capacity(16 * 1024);
6496    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6497    writeln!(
6498        code,
6499        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6500    )?;
6501    writeln!(code)?;
6502    writeln!(code, "mod model;")?;
6503    writeln!(code)?;
6504    writeln!(code, "use std::io::Write;")?;
6505    writeln!(code, "use std::time::Instant;")?;
6506    writeln!(code, "use serde::Deserialize;")?;
6507    writeln!(code)?;
6508
6509    // -- main function --
6510    writeln!(code, "fn main() {{")?;
6511    writeln!(
6512        code,
6513        "    let args: Vec<String> = std::env::args().collect();"
6514    )?;
6515    writeln!(code)?;
6516    writeln!(
6517        code,
6518        "    // Detect --serve mode (only requires weights + tokenizer)"
6519    )?;
6520    writeln!(
6521        code,
6522        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6523    )?;
6524    writeln!(code)?;
6525    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6526    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6527    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6528    writeln!(code, "        std::process::exit(1);")?;
6529    writeln!(code, "    }}")?;
6530    writeln!(code)?;
6531    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6532    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6533    writeln!(code, "        std::process::exit(1);")?;
6534    writeln!(code, "    }}")?;
6535    writeln!(code)?;
6536    writeln!(code, "    let weights_path = &args[1];")?;
6537    writeln!(code, "    let tokenizer_path = &args[2];")?;
6538    writeln!(code)?;
6539    writeln!(code, "    // Parse optional flags")?;
6540    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6541    writeln!(code, "    let mut port: u16 = 8080;")?;
6542    writeln!(
6543        code,
6544        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6545    )?;
6546    writeln!(
6547        code,
6548        "    let profile = args.iter().any(|a| a == \"--profile\");"
6549    )?;
6550    writeln!(code, "    let mut i = 3;")?;
6551    writeln!(code, "    while i < args.len() {{")?;
6552    writeln!(
6553        code,
6554        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6555    )?;
6556    writeln!(
6557        code,
6558        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6559    )?;
6560    writeln!(code, "            i += 2;")?;
6561    writeln!(
6562        code,
6563        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6564    )?;
6565    writeln!(
6566        code,
6567        "            port = args[i + 1].parse().unwrap_or(8080);"
6568    )?;
6569    writeln!(code, "            i += 2;")?;
6570    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6571    writeln!(code, "            i += 1;")?;
6572    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6573    writeln!(code, "            i += 1;")?;
6574    writeln!(code, "        }} else {{")?;
6575    writeln!(code, "            i += 1;")?;
6576    writeln!(code, "        }}")?;
6577    writeln!(code, "    }}")?;
6578    writeln!(code)?;
6579
6580    // -- load model (shared by both modes) --
6581    writeln!(
6582        code,
6583        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6584    )?;
6585    writeln!(
6586        code,
6587        "    let weights_file = std::fs::File::open(weights_path)"
6588    )?;
6589    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6590    writeln!(
6591        code,
6592        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6593    )?;
6594    writeln!(code)?;
6595    writeln!(code, "    // Load tokenizer")?;
6596    writeln!(
6597        code,
6598        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6599    )?;
6600    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6601    writeln!(code)?;
6602    writeln!(code, "    // Create Metal model")?;
6603    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6604    writeln!(
6605        code,
6606        "    let mut model = model::MetalModel::new(&weights_mmap);"
6607    )?;
6608    writeln!(code)?;
6609
6610    // -- branch: serve vs CLI --
6611    writeln!(code, "    if serve_mode {{")?;
6612    writeln!(code, "        serve(model, tokenizer, port);")?;
6613    writeln!(code, "    }} else {{")?;
6614    writeln!(code, "        let prompt = &args[3];")?;
6615    writeln!(
6616        code,
6617        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6618    )?;
6619    writeln!(code, "    }}")?;
6620    writeln!(code, "}}")?;
6621    writeln!(code)?;
6622
6623    // -- cli_mode function --
6624    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6625    writeln!(code, "    // Tokenize prompt")?;
6626    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6627    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6628    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6629    writeln!(code)?;
6630    writeln!(
6631        code,
6632        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6633    )?;
6634    writeln!(
6635        code,
6636        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6637    )?;
6638    writeln!(
6639        code,
6640        "    // The last token uses synchronous forward() to get logits."
6641    )?;
6642    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6643    writeln!(code, "    let prefill_start = Instant::now();")?;
6644    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6645    writeln!(
6646        code,
6647        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6648    )?;
6649    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6650    writeln!(code, "    }} else {{")?;
6651    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6652    writeln!(code, "    }};")?;
6653    writeln!(
6654        code,
6655        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6656    )?;
6657    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6658    writeln!(
6659        code,
6660        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6661    )?;
6662    writeln!(
6663        code,
6664        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6665    )?;
6666    writeln!(code)?;
6667    writeln!(code, "    // Generate tokens")?;
6668    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6669    writeln!(code, "    let gen_start = Instant::now();")?;
6670    writeln!(code, "    let mut generated_count: usize = 0;")?;
6671    writeln!(code)?;
6672    writeln!(code, "    for _ in 0..max_tokens {{")?;
6673    writeln!(
6674        code,
6675        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6676    )?;
6677    writeln!(code, "            if !quiet {{")?;
6678    writeln!(code, "                print!(\"{{}}\", text);")?;
6679    writeln!(code, "                std::io::stdout().flush().ok();")?;
6680    writeln!(code, "            }}")?;
6681    writeln!(code, "        }}")?;
6682    writeln!(code, "        generated_count += 1;")?;
6683    writeln!(code)?;
6684    writeln!(
6685        code,
6686        "        // Use profiling forward for first token when --profile is set"
6687    )?;
6688    writeln!(
6689        code,
6690        "        let logits = if profile && generated_count == 1 {{"
6691    )?;
6692    writeln!(code, "            model.forward_profile(next_token)")?;
6693    writeln!(code, "        }} else {{")?;
6694    writeln!(code, "            model.forward(next_token)")?;
6695    writeln!(code, "        }};")?;
6696    writeln!(code, "        next_token = argmax(&logits);")?;
6697    writeln!(code)?;
6698    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6699    writeln!(code, "        if next_token == 2 {{")?;
6700    writeln!(code, "            break;")?;
6701    writeln!(code, "        }}")?;
6702    writeln!(code)?;
6703    writeln!(
6704        code,
6705        "        // Yield between tokens to reduce sustained GPU thermal load."
6706    )?;
6707    writeln!(
6708        code,
6709        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6710    )?;
6711    writeln!(
6712        code,
6713        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6714    )?;
6715    writeln!(
6716        code,
6717        "        // briefly, providing a micro-break that helps sustain peak throughput."
6718    )?;
6719    writeln!(code, "        std::thread::yield_now();")?;
6720    writeln!(code, "    }}")?;
6721    writeln!(code, "    if !quiet {{")?;
6722    writeln!(code, "        println!();")?;
6723    writeln!(code, "    }}")?;
6724    writeln!(
6725        code,
6726        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6727    )?;
6728    writeln!(
6729        code,
6730        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6731    )?;
6732    writeln!(
6733        code,
6734        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6735    )?;
6736    writeln!(code, "}}")?;
6737    writeln!(code)?;
6738
6739    // -- argmax helper --
6740    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6741    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6742    writeln!(code, "    logits.iter()")?;
6743    writeln!(code, "        .enumerate()")?;
6744    writeln!(
6745        code,
6746        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6747    )?;
6748    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6749    writeln!(code, "        .unwrap_or(0)")?;
6750    writeln!(code, "}}")?;
6751    writeln!(code)?;
6752
6753    // -- Request/Response types for OpenAI API --
6754    writeln!(
6755        code,
6756        "// -----------------------------------------------------------------------"
6757    )?;
6758    writeln!(code, "// OpenAI-compatible API server")?;
6759    writeln!(
6760        code,
6761        "// -----------------------------------------------------------------------"
6762    )?;
6763    writeln!(code)?;
6764    writeln!(code, "#[derive(Deserialize)]")?;
6765    writeln!(code, "struct ChatRequest {{")?;
6766    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6767    writeln!(code, "    #[serde(default)]")?;
6768    writeln!(code, "    stream: Option<bool>,")?;
6769    writeln!(code, "    #[serde(default)]")?;
6770    writeln!(code, "    max_tokens: Option<usize>,")?;
6771    writeln!(code, "    #[serde(default)]")?;
6772    writeln!(code, "    temperature: Option<f32>,")?;
6773    writeln!(code, "    #[serde(default)]")?;
6774    writeln!(code, "    model: Option<String>,")?;
6775    writeln!(code, "}}")?;
6776    writeln!(code)?;
6777    writeln!(code, "#[derive(Deserialize)]")?;
6778    writeln!(code, "struct ChatMessage {{")?;
6779    writeln!(code, "    role: String,")?;
6780    writeln!(code, "    content: String,")?;
6781    writeln!(code, "}}")?;
6782    writeln!(code)?;
6783
6784    // -- format_chat_messages --
6785    writeln!(
6786        code,
6787        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6788    )?;
6789    writeln!(code, "    let mut prompt = String::new();")?;
6790    writeln!(code, "    for msg in messages {{")?;
6791    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6792    writeln!(code, "    }}")?;
6793    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6794    writeln!(code, "    prompt")?;
6795    writeln!(code, "}}")?;
6796    writeln!(code)?;
6797
6798    // -- prefill helper --
6799    writeln!(
6800        code,
6801        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6802    )?;
6803    writeln!(code, "    let len = tokens.len();")?;
6804    writeln!(code, "    if len > 1 {{")?;
6805    writeln!(
6806        code,
6807        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6808    )?;
6809    writeln!(code, "    }}")?;
6810    writeln!(code, "    model.forward(tokens[len - 1])")?;
6811    writeln!(code, "}}")?;
6812    writeln!(code)?;
6813
6814    // -- serve function --
6815    writeln!(
6816        code,
6817        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6818    )?;
6819    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6820    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6821    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6822    writeln!(
6823        code,
6824        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6825    )?;
6826    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6827    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6828    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6829    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6830    writeln!(code)?;
6831    writeln!(code, "    for request in server.incoming_requests() {{")?;
6832    writeln!(code, "        let method = request.method().to_string();")?;
6833    writeln!(code, "        let url = request.url().to_string();")?;
6834    writeln!(code)?;
6835    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6836
6837    // -- POST /v1/chat/completions --
6838    writeln!(
6839        code,
6840        "            (\"POST\", \"/v1/chat/completions\") => {{"
6841    )?;
6842    writeln!(
6843        code,
6844        "                handle_chat_completion(&mut model, &tokenizer, request);"
6845    )?;
6846    writeln!(code, "            }}")?;
6847
6848    // -- GET /v1/models --
6849    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6850    writeln!(code, "                let body = serde_json::json!({{")?;
6851    writeln!(code, "                    \"object\": \"list\",")?;
6852    writeln!(code, "                    \"data\": [{{")?;
6853    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6854    writeln!(code, "                        \"object\": \"model\",")?;
6855    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6856    writeln!(code, "                    }}]")?;
6857    writeln!(code, "                }});")?;
6858    writeln!(
6859        code,
6860        "                let resp = tiny_http::Response::from_string(body.to_string())"
6861    )?;
6862    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6863    writeln!(code, "                request.respond(resp).ok();")?;
6864    writeln!(code, "            }}")?;
6865
6866    // -- GET /health --
6867    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6868    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6869    writeln!(code, "                request.respond(resp).ok();")?;
6870    writeln!(code, "            }}")?;
6871
6872    // -- 404 --
6873    writeln!(code, "            _ => {{")?;
6874    writeln!(
6875        code,
6876        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6877    )?;
6878    writeln!(code, "                    .with_status_code(404);")?;
6879    writeln!(code, "                request.respond(resp).ok();")?;
6880    writeln!(code, "            }}")?;
6881    writeln!(code, "        }}")?;
6882    writeln!(code, "    }}")?;
6883    writeln!(code, "}}")?;
6884    writeln!(code)?;
6885
6886    // -- handle_chat_completion --
6887    writeln!(code, "fn handle_chat_completion(")?;
6888    writeln!(code, "    model: &mut model::MetalModel,")?;
6889    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6890    writeln!(code, "    mut request: tiny_http::Request,")?;
6891    writeln!(code, ") {{")?;
6892    writeln!(code, "    // Read request body")?;
6893    writeln!(code, "    let mut body = String::new();")?;
6894    writeln!(
6895        code,
6896        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6897    )?;
6898    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6899    writeln!(code, "            .with_status_code(400);")?;
6900    writeln!(code, "        request.respond(resp).ok();")?;
6901    writeln!(code, "        return;")?;
6902    writeln!(code, "    }}")?;
6903    writeln!(code)?;
6904    writeln!(code, "    // Parse JSON")?;
6905    writeln!(
6906        code,
6907        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6908    )?;
6909    writeln!(code, "        Ok(r) => r,")?;
6910    writeln!(code, "        Err(e) => {{")?;
6911    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6912    writeln!(code, "                .with_status_code(400);")?;
6913    writeln!(code, "            request.respond(resp).ok();")?;
6914    writeln!(code, "            return;")?;
6915    writeln!(code, "        }}")?;
6916    writeln!(code, "    }};")?;
6917    writeln!(code)?;
6918    writeln!(
6919        code,
6920        "    let prompt = format_chat_messages(&req.messages);"
6921    )?;
6922    writeln!(
6923        code,
6924        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6925    )?;
6926    writeln!(code, "        Ok(e) => e,")?;
6927    writeln!(code, "        Err(e) => {{")?;
6928    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6929    writeln!(code, "                .with_status_code(500);")?;
6930    writeln!(code, "            request.respond(resp).ok();")?;
6931    writeln!(code, "            return;")?;
6932    writeln!(code, "        }}")?;
6933    writeln!(code, "    }};")?;
6934    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6935    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6936    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6937    writeln!(
6938        code,
6939        "    let _temperature = req.temperature.unwrap_or(1.0);"
6940    )?;
6941    writeln!(code)?;
6942
6943    // -- Reset KV cache for each request --
6944    writeln!(code, "    model.reset();")?;
6945    writeln!(code)?;
6946
6947    // -- Prefill with timing --
6948    writeln!(code, "    let prefill_start = Instant::now();")?;
6949    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6950    writeln!(
6951        code,
6952        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6953    )?;
6954    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6955    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6956    writeln!(code)?;
6957
6958    writeln!(code, "    if stream {{")?;
6959
6960    // -- SSE streaming response --
6961    writeln!(
6962        code,
6963        "        // SSE streaming: generate tokens and build SSE body"
6964    )?;
6965    writeln!(code, "        let gen_start = Instant::now();")?;
6966    writeln!(code, "        let mut generated_count: usize = 0;")?;
6967    writeln!(code, "        let mut sse_body = String::new();")?;
6968    writeln!(code, "        for _ in 0..max_tokens {{")?;
6969    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6970    writeln!(
6971        code,
6972        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6973    )?;
6974    writeln!(
6975        code,
6976        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6977    )?;
6978    writeln!(
6979        code,
6980        "                // escaped includes surrounding quotes, strip them"
6981    )?;
6982    writeln!(
6983        code,
6984        "                let inner = &escaped[1..escaped.len()-1];"
6985    )?;
6986    writeln!(code, "                sse_body.push_str(&format!(")?;
6987    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6988    writeln!(code, "                    inner")?;
6989    writeln!(code, "                ));")?;
6990    writeln!(code, "            }}")?;
6991    writeln!(code, "            generated_count += 1;")?;
6992    writeln!(code, "            let logits = model.forward(next_token);")?;
6993    writeln!(code, "            next_token = argmax(&logits);")?;
6994    writeln!(code, "        }}")?;
6995    writeln!(
6996        code,
6997        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6998    )?;
6999    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7000    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
7001    writeln!(code)?;
7002    writeln!(
7003        code,
7004        "        // Final chunk with finish_reason, timing, and DONE sentinel"
7005    )?;
7006    writeln!(code, "        sse_body.push_str(&format!(")?;
7007    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
7008    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
7009    writeln!(code, "        ));")?;
7010    writeln!(code)?;
7011    writeln!(
7012        code,
7013        "        let resp = tiny_http::Response::from_string(sse_body)"
7014    )?;
7015    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
7016    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
7017    writeln!(code, "        request.respond(resp).ok();")?;
7018
7019    writeln!(code, "    }} else {{")?;
7020
7021    // -- Non-streaming response --
7022    writeln!(
7023        code,
7024        "        // Non-streaming: generate all tokens, return JSON"
7025    )?;
7026    writeln!(code, "        let gen_start = Instant::now();")?;
7027    writeln!(code, "        let mut generated_count: usize = 0;")?;
7028    writeln!(code, "        let mut generated = String::new();")?;
7029    writeln!(code, "        for _ in 0..max_tokens {{")?;
7030    writeln!(code, "            if next_token == 2 {{ break; }}")?;
7031    writeln!(
7032        code,
7033        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7034    )?;
7035    writeln!(code, "                generated.push_str(&text);")?;
7036    writeln!(code, "            }}")?;
7037    writeln!(code, "            generated_count += 1;")?;
7038    writeln!(code, "            let logits = model.forward(next_token);")?;
7039    writeln!(code, "            next_token = argmax(&logits);")?;
7040    writeln!(code, "        }}")?;
7041    writeln!(
7042        code,
7043        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7044    )?;
7045    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7046    writeln!(code)?;
7047    writeln!(code, "        let resp_json = serde_json::json!({{")?;
7048    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
7049    writeln!(code, "            \"object\": \"chat.completion\",")?;
7050    writeln!(code, "            \"choices\": [{{")?;
7051    writeln!(code, "                \"index\": 0,")?;
7052    writeln!(code, "                \"message\": {{")?;
7053    writeln!(code, "                    \"role\": \"assistant\",")?;
7054    writeln!(code, "                    \"content\": generated")?;
7055    writeln!(code, "                }},")?;
7056    writeln!(code, "                \"finish_reason\": \"stop\"")?;
7057    writeln!(code, "            }}],")?;
7058    writeln!(code, "            \"usage\": {{")?;
7059    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
7060    writeln!(
7061        code,
7062        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7063    )?;
7064    writeln!(
7065        code,
7066        "                \"generation_tokens\": generated_count,"
7067    )?;
7068    writeln!(
7069        code,
7070        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7071    )?;
7072    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
7073    writeln!(code, "            }}")?;
7074    writeln!(code, "        }});")?;
7075    writeln!(
7076        code,
7077        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
7078    )?;
7079    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7080    writeln!(code, "        request.respond(resp).ok();")?;
7081    writeln!(code, "    }}")?;
7082    writeln!(code, "}}")?;
7083
7084    Ok(code)
7085}
7086
7087// ---------------------------------------------------------------------------
7088// Tests
7089// ---------------------------------------------------------------------------
7090
7091#[cfg(test)]
7092mod tests {
7093    use super::*;
7094    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7095
7096    fn minimal_config() -> ModelConfig {
7097        ModelConfig {
7098            architecture: Architecture::Llama,
7099            hidden_size: 64,
7100            intermediate_size: 128,
7101            num_layers: 2,
7102            num_attention_heads: 4,
7103            num_kv_heads: 4,
7104            head_dim: 16,
7105            vocab_size: 256,
7106            max_seq_len: 512,
7107            rms_norm_eps: 1e-5,
7108            rope_theta: 10000.0,
7109            dtype: DType::F32,
7110            sliding_window_size: None,
7111            qkv_bias: false,
7112        }
7113    }
7114
7115    fn minimal_graph() -> Graph {
7116        Graph::new("test-metal").with_config(minimal_config())
7117    }
7118
7119    #[test]
7120    fn generate_metal_project_creates_files() {
7121        let dir = tempfile::tempdir().unwrap();
7122        let graph = minimal_graph();
7123        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7124
7125        assert!(
7126            dir.path().join("Cargo.toml").exists(),
7127            "Cargo.toml should be created"
7128        );
7129        assert!(
7130            dir.path().join("src/model.rs").exists(),
7131            "src/model.rs should be created"
7132        );
7133        assert!(
7134            dir.path().join("src/main.rs").exists(),
7135            "src/main.rs should be created"
7136        );
7137        assert!(
7138            dir.path().join("shaders/kernels.metal").exists(),
7139            "shaders/kernels.metal should be created"
7140        );
7141    }
7142
7143    #[test]
7144    fn generated_cargo_toml_has_metal_dep() {
7145        let toml = generate_cargo_toml("my-model");
7146        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7147        assert!(
7148            toml.contains("tokenizers"),
7149            "Cargo.toml should depend on tokenizers"
7150        );
7151        assert!(
7152            toml.contains("memmap2"),
7153            "Cargo.toml should depend on memmap2"
7154        );
7155        assert!(toml.contains("half"), "Cargo.toml should depend on half");
7156    }
7157
7158    #[test]
7159    fn generated_model_rs_contains_metal_code() {
7160        let config = minimal_config();
7161        let model_rs = generate_model_rs(&config).unwrap();
7162
7163        assert!(
7164            model_rs.contains("pub struct MetalModel"),
7165            "model.rs should define MetalModel struct"
7166        );
7167        assert!(
7168            model_rs.contains("matmul_pipeline: ComputePipelineState"),
7169            "MetalModel should have matmul_pipeline field"
7170        );
7171        assert!(
7172            model_rs.contains("Device::system_default()"),
7173            "model.rs should use Metal device"
7174        );
7175        assert!(
7176            model_rs.contains("new_library_with_source"),
7177            "model.rs should compile Metal shaders"
7178        );
7179        assert!(
7180            model_rs.contains("fn new(weights: &[u8])"),
7181            "MetalModel should implement new()"
7182        );
7183        assert!(
7184            model_rs.contains("fn forward(&mut self, token_id: u32)"),
7185            "MetalModel should implement forward()"
7186        );
7187    }
7188
7189    #[test]
7190    fn generated_shaders_contain_kernel_names() {
7191        let shaders = generate_metal_shaders(&minimal_config());
7192
7193        assert!(
7194            shaders.contains("kernel void matmul_vec"),
7195            "shaders should contain matmul_vec kernel"
7196        );
7197        assert!(
7198            shaders.contains("kernel void rms_norm"),
7199            "shaders should contain rms_norm kernel"
7200        );
7201        assert!(
7202            shaders.contains("kernel void rope"),
7203            "shaders should contain rope kernel"
7204        );
7205        assert!(
7206            shaders.contains("kernel void softmax"),
7207            "shaders should contain softmax kernel"
7208        );
7209        assert!(
7210            shaders.contains("kernel void silu_mul("),
7211            "shaders should contain silu_mul kernel"
7212        );
7213        assert!(
7214            shaders.contains("kernel void silu_mul_fused"),
7215            "shaders should contain silu_mul_fused kernel"
7216        );
7217        assert!(
7218            shaders.contains("kernel void elementwise_add"),
7219            "shaders should contain elementwise_add kernel"
7220        );
7221        assert!(
7222            shaders.contains("kernel void attention"),
7223            "shaders should contain attention kernel"
7224        );
7225        assert!(
7226            shaders.contains("kernel void add_inplace"),
7227            "shaders should contain add_inplace kernel"
7228        );
7229        assert!(
7230            shaders.contains("kernel void copy_buffer"),
7231            "shaders should contain copy_buffer kernel"
7232        );
7233        assert!(
7234            shaders.contains("kernel void copy_offset"),
7235            "shaders should contain copy_offset kernel"
7236        );
7237    }
7238
7239    #[test]
7240    fn generated_shaders_use_simdgroup_features() {
7241        let shaders = generate_metal_shaders(&minimal_config());
7242
7243        assert!(
7244            shaders.contains("threadgroup_barrier"),
7245            "shaders should use threadgroup barriers"
7246        );
7247        assert!(
7248            shaders.contains("threadgroup float"),
7249            "shaders should use threadgroup shared memory"
7250        );
7251        assert!(
7252            shaders.contains("thread_index_in_threadgroup"),
7253            "shaders should use threadgroup indexing"
7254        );
7255        assert!(
7256            shaders.contains("simd_sum"),
7257            "shaders should use simd_sum for warp-level reduction"
7258        );
7259        assert!(
7260            shaders.contains("simd_max"),
7261            "attention kernel should use simd_max for cooperative softmax"
7262        );
7263        assert!(
7264            shaders.contains("thread_index_in_simdgroup"),
7265            "shaders should use simdgroup lane indexing"
7266        );
7267        assert!(
7268            shaders.contains("simdgroup_index_in_threadgroup"),
7269            "shaders should use simdgroup indexing within threadgroup"
7270        );
7271        assert!(
7272            shaders.contains("float4"),
7273            "matmul_vec should use float4 vectorized loads"
7274        );
7275    }
7276
7277    #[test]
7278    fn generated_main_rs_has_tokenizer_usage() {
7279        let config = minimal_config();
7280        let main_rs = generate_main_rs("test-model", &config).unwrap();
7281
7282        assert!(
7283            main_rs.contains("tokenizers::Tokenizer"),
7284            "main.rs should use tokenizers crate"
7285        );
7286        assert!(
7287            main_rs.contains("MetalModel::new"),
7288            "main.rs should call MetalModel::new"
7289        );
7290        assert!(
7291            main_rs.contains("model.forward"),
7292            "main.rs should call model.forward"
7293        );
7294        assert!(
7295            main_rs.contains("memmap2"),
7296            "main.rs should use memmap2 for zero-copy weight loading"
7297        );
7298    }
7299
7300    #[test]
7301    fn missing_config_returns_error() {
7302        let dir = tempfile::tempdir().unwrap();
7303        let graph = Graph::new("no-config");
7304        let result = generate_metal_project(&graph, dir.path(), "fail");
7305        assert!(
7306            matches!(result, Err(MetalCodegenError::MissingConfig)),
7307            "should fail with MissingConfig when graph has no config"
7308        );
7309    }
7310
7311    #[test]
7312    fn sanitize_name_works() {
7313        assert_eq!(sanitize_name("My Model!"), "my-model");
7314        assert_eq!(sanitize_name("test_model"), "test-model");
7315        assert_eq!(sanitize_name("simple"), "simple");
7316    }
7317
7318    #[test]
7319    fn generated_forward_uses_single_command_buffer() {
7320        let config = minimal_config();
7321        let model_rs = generate_model_rs(&config).unwrap();
7322
7323        // The forward function should create exactly one command buffer.
7324        // Use the exact signature to avoid matching forward_prefill/forward_profile.
7325        let forward_start = model_rs
7326            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7327            .unwrap();
7328        let forward_body = &model_rs[forward_start..];
7329        // End at the next pub/private method
7330        let forward_end = forward_body
7331            .find("\n    pub fn forward_profile")
7332            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7333            .or_else(|| forward_body.find("\n    fn dispatch_"))
7334            .unwrap_or(forward_body.len());
7335        let forward_code = &forward_body[..forward_end];
7336
7337        // Should have exactly one new_command_buffer call
7338        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7339        assert_eq!(
7340            cmd_buf_count, 1,
7341            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7342        );
7343
7344        // Should have exactly one commit call
7345        let commit_count = forward_code.matches("cmd.commit()").count();
7346        assert_eq!(
7347            commit_count, 1,
7348            "forward() should commit exactly once, found {commit_count}"
7349        );
7350
7351        // Should wait: once for cmd + possibly once for prev_cmd drain
7352        let wait_count = forward_code.matches("wait_until_completed()").count();
7353        assert!(
7354            wait_count >= 1 && wait_count <= 2,
7355            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7356        );
7357    }
7358
7359    #[test]
7360    fn generated_model_has_preallocated_working_buffers() {
7361        let config = minimal_config();
7362        let model_rs = generate_model_rs(&config).unwrap();
7363
7364        for buf_name in &[
7365            "normed_buf",
7366            "qkv_buf",
7367            "attn_out_buf",
7368            "attn_proj_buf",
7369            "gate_up_buf",
7370            "ffn_hidden_buf",
7371            "ffn_out_buf",
7372            "add_tmp_buf",
7373        ] {
7374            assert!(
7375                model_rs.contains(&format!("{buf_name}: Buffer")),
7376                "MetalModel should have pre-allocated {buf_name} field"
7377            );
7378        }
7379    }
7380
7381    #[test]
7382    fn generated_dispatch_helpers_take_compute_encoder_ref() {
7383        let config = minimal_config();
7384        let model_rs = generate_model_rs(&config).unwrap();
7385
7386        for method in &[
7387            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7388            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7389            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7390            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7391            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7392            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7393            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7394            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7395            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7396            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7397            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7398            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7399            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7400        ] {
7401            assert!(
7402                model_rs.contains(method),
7403                "model.rs should contain dispatch helper: {method}"
7404            );
7405        }
7406    }
7407
7408    #[test]
7409    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7410        let config = minimal_config();
7411        let model_rs = generate_model_rs(&config).unwrap();
7412
7413        // Find dispatch helpers section and check none create their own encoders
7414        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7415        let helpers_code = &model_rs[helpers_start..];
7416
7417        // None of the dispatch_ helpers should call new_command_buffer
7418        assert!(
7419            !helpers_code.contains("self.queue.new_command_buffer()"),
7420            "dispatch helpers should not create their own command buffers"
7421        );
7422
7423        // None should create their own compute encoders
7424        assert!(
7425            !helpers_code.contains("new_compute_command_encoder()"),
7426            "dispatch helpers should not create their own compute encoders"
7427        );
7428
7429        // None should call end_encoding
7430        assert!(
7431            !helpers_code.contains("end_encoding()"),
7432            "dispatch helpers should not call end_encoding"
7433        );
7434
7435        // None should call commit or wait
7436        assert!(
7437            !helpers_code.contains(".commit()"),
7438            "dispatch helpers should not commit command buffers"
7439        );
7440        assert!(
7441            !helpers_code.contains("wait_until_completed"),
7442            "dispatch helpers should not wait on command buffers"
7443        );
7444    }
7445
7446    #[test]
7447    fn generated_forward_batches_compute_encoders() {
7448        let config = minimal_config();
7449        let model_rs = generate_model_rs(&config).unwrap();
7450
7451        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7452        let forward_start = model_rs
7453            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7454            .unwrap();
7455        let forward_body = &model_rs[forward_start..];
7456        let forward_end = forward_body
7457            .find("\n    pub fn forward_profile")
7458            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7459            .or_else(|| forward_body.find("\n    fn dispatch_"))
7460            .unwrap_or(forward_body.len());
7461        let forward_code = &forward_body[..forward_end];
7462
7463        // Forward should not allocate new buffers
7464        assert!(
7465            !forward_code.contains("device.new_buffer"),
7466            "forward() should not allocate new buffers per call"
7467        );
7468
7469        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7470        // Copy operations use compute copy kernels instead of blit encoders.
7471        let compute_encoder_count = forward_code
7472            .matches("new_compute_command_encoder()")
7473            .count();
7474        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7475
7476        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7477        assert_eq!(
7478            compute_encoder_count, 1,
7479            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7480        );
7481        assert_eq!(
7482            blit_encoder_count, 0,
7483            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7484        );
7485    }
7486
7487    #[test]
7488    fn generated_forward_uses_add_inplace() {
7489        let config = minimal_config();
7490        let model_rs = generate_model_rs(&config).unwrap();
7491
7492        // Should use in-place add (no blit copy-back needed)
7493        assert!(
7494            model_rs.contains("dispatch_add_inplace"),
7495            "forward() should use dispatch_add_inplace for residual connections"
7496        );
7497        assert!(
7498            model_rs.contains("add_inplace_pipeline"),
7499            "MetalModel should have add_inplace_pipeline"
7500        );
7501    }
7502
7503    fn minimal_q8_config() -> ModelConfig {
7504        ModelConfig {
7505            architecture: Architecture::Llama,
7506            hidden_size: 64,
7507            intermediate_size: 128,
7508            num_layers: 2,
7509            num_attention_heads: 4,
7510            num_kv_heads: 4,
7511            head_dim: 16,
7512            vocab_size: 256,
7513            max_seq_len: 512,
7514            rms_norm_eps: 1e-5,
7515            rope_theta: 10000.0,
7516            dtype: DType::Q8_0,
7517            sliding_window_size: None,
7518            qkv_bias: false,
7519        }
7520    }
7521
7522    #[test]
7523    fn generated_shaders_contain_q8_kernel() {
7524        let shaders = generate_metal_shaders(&minimal_config());
7525
7526        assert!(
7527            shaders.contains("kernel void matmul_vec_q8"),
7528            "shaders should contain matmul_vec_q8 kernel"
7529        );
7530        assert!(
7531            shaders.contains("device const uchar* matrix"),
7532            "matmul_vec_q8 should accept raw Q8_0 bytes"
7533        );
7534        assert!(
7535            shaders.contains("packed_short4"),
7536            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7537        );
7538        assert!(
7539            shaders.contains("as_type<char2>"),
7540            "matmul_vec_q8 should bitcast short lanes to char2"
7541        );
7542        assert!(
7543            shaders.contains("device const half*"),
7544            "matmul_vec_q8 should read f16 scale via half pointer"
7545        );
7546    }
7547
7548    #[test]
7549    fn generated_model_uses_fused_qkv_projections() {
7550        let config = minimal_config();
7551        let model_rs = generate_model_rs(&config).unwrap();
7552
7553        // Should have fused QKV weight in layer buffers
7554        assert!(
7555            model_rs.contains("qkv_weight: Buffer"),
7556            "LayerBuffers should have fused qkv_weight field"
7557        );
7558        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7559        assert!(
7560            !model_rs.contains("    q_weight: Buffer"),
7561            "LayerBuffers should not have separate q_weight field"
7562        );
7563        assert!(
7564            !model_rs.contains("    k_weight: Buffer"),
7565            "LayerBuffers should not have separate k_weight field"
7566        );
7567        assert!(
7568            !model_rs.contains("    v_weight: Buffer"),
7569            "LayerBuffers should not have separate v_weight field"
7570        );
7571
7572        // Should have fused gate_up_weight
7573        assert!(
7574            model_rs.contains("gate_up_weight: Buffer"),
7575            "LayerBuffers should have fused gate_up_weight field"
7576        );
7577        // Should NOT have separate gate/up weight fields
7578        assert!(
7579            !model_rs.contains("    gate_weight: Buffer"),
7580            "LayerBuffers should not have separate gate_weight field"
7581        );
7582        assert!(
7583            !model_rs.contains("    up_weight: Buffer"),
7584            "LayerBuffers should not have separate up_weight field"
7585        );
7586
7587        // Should have fused working buffers
7588        assert!(
7589            model_rs.contains("qkv_buf: Buffer"),
7590            "MetalModel should have fused qkv_buf"
7591        );
7592        assert!(
7593            model_rs.contains("gate_up_buf: Buffer"),
7594            "MetalModel should have fused gate_up_buf"
7595        );
7596
7597        // Forward pass should use fused dispatch
7598        assert!(
7599            model_rs.contains("dispatch_silu_mul_fused"),
7600            "forward pass should use dispatch_silu_mul_fused"
7601        );
7602        assert!(
7603            model_rs.contains("dispatch_rope_offset"),
7604            "forward pass should use dispatch_rope_offset for fused QKV"
7605        );
7606        assert!(
7607            model_rs.contains("dispatch_attention_offset"),
7608            "forward pass should use dispatch_attention_offset for fused QKV"
7609        );
7610    }
7611
7612    #[test]
7613    fn q8_model_has_matmul_q8_pipeline() {
7614        let config = minimal_q8_config();
7615        let model_rs = generate_model_rs(&config).unwrap();
7616
7617        assert!(
7618            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7619            "MetalModel should have matmul_q8_pipeline field"
7620        );
7621        assert!(
7622            model_rs.contains("matmul_q8_pipeline,"),
7623            "MetalModel Self should include matmul_q8_pipeline"
7624        );
7625    }
7626
7627    #[test]
7628    fn q8_model_uses_dispatch_matmul_q8() {
7629        let config = minimal_q8_config();
7630        let model_rs = generate_model_rs(&config).unwrap();
7631
7632        assert!(
7633            model_rs.contains("dispatch_matmul_q8"),
7634            "Q8_0 model should use dispatch_matmul_q8 for projections"
7635        );
7636        assert!(
7637            model_rs.contains("fn dispatch_matmul_q8"),
7638            "model.rs should define dispatch_matmul_q8 method"
7639        );
7640    }
7641
7642    #[test]
7643    fn q8_model_loads_raw_bytes_not_dequantized() {
7644        let config = minimal_q8_config();
7645        let model_rs = generate_model_rs(&config).unwrap();
7646
7647        // Should NOT contain dequantization code
7648        assert!(
7649            !model_rs.contains("f16_to_f32"),
7650            "Q8_0 model should not dequantize weights to f32"
7651        );
7652        assert!(
7653            !model_rs.contains("f32_data"),
7654            "Q8_0 model should not create f32 weight data"
7655        );
7656
7657        // Should load raw Q8_0 bytes directly
7658        assert!(
7659            model_rs.contains("total_raw as u64"),
7660            "Q8_0 model should load raw bytes into Metal buffer"
7661        );
7662    }
7663
7664    #[test]
7665    fn q8_model_norms_stay_f32() {
7666        let config = minimal_q8_config();
7667        let model_rs = generate_model_rs(&config).unwrap();
7668
7669        // Norm weights should still use f32 buffers
7670        assert!(
7671            model_rs.contains("let attn_norm = next_f32_buffer"),
7672            "attn_norm should use f32 buffer even for Q8_0 models"
7673        );
7674        assert!(
7675            model_rs.contains("let ffn_norm = next_f32_buffer"),
7676            "ffn_norm should use f32 buffer even for Q8_0 models"
7677        );
7678        assert!(
7679            model_rs.contains("let norm_buf = next_f32_buffer"),
7680            "final norm should use f32 buffer even for Q8_0 models"
7681        );
7682    }
7683
7684    #[test]
7685    fn q8_model_uses_fused_weight_loading() {
7686        let config = minimal_q8_config();
7687        let model_rs = generate_model_rs(&config).unwrap();
7688
7689        // Should use fused Q8 buffer loading for QKV
7690        assert!(
7691            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7692            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7693        );
7694        // Should use fused Q8 buffer loading for gate+up
7695        assert!(
7696            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7697            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7698        );
7699        // Should still use regular q8 buffer for individual weights
7700        assert!(
7701            model_rs.contains("let o_weight = next_q8_buffer"),
7702            "Q8_0 model should use next_q8_buffer for O weight"
7703        );
7704        assert!(
7705            model_rs.contains("let down_weight = next_q8_buffer"),
7706            "Q8_0 model should use next_q8_buffer for down weight"
7707        );
7708    }
7709
7710    #[test]
7711    fn f32_model_does_not_use_q8_dispatch() {
7712        let config = minimal_config();
7713        let model_rs = generate_model_rs(&config).unwrap();
7714
7715        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7716        let forward_start = model_rs
7717            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7718            .unwrap();
7719        let forward_body = &model_rs[forward_start..];
7720        let forward_end = forward_body
7721            .find("\n    fn dispatch_")
7722            .unwrap_or(forward_body.len());
7723        let forward_code = &forward_body[..forward_end];
7724
7725        assert!(
7726            !forward_code.contains("dispatch_matmul_q8"),
7727            "f32 model forward should not use dispatch_matmul_q8"
7728        );
7729    }
7730
7731    #[test]
7732    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7733        let config = minimal_q8_config();
7734        let model_rs = generate_model_rs(&config).unwrap();
7735
7736        assert!(
7737            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7738            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7739        );
7740    }
7741
7742    #[test]
7743    fn generated_model_has_double_buffered_prefill() {
7744        let config = minimal_config();
7745        let model_rs = generate_model_rs(&config).unwrap();
7746
7747        // MetalModel should have prev_cmd field for double-buffered prefill
7748        assert!(
7749            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7750            "MetalModel should have prev_cmd field for double-buffered prefill"
7751        );
7752
7753        // Should have forward_prefill method
7754        assert!(
7755            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7756            "MetalModel should have forward_prefill method"
7757        );
7758
7759        // forward() should drain prev_cmd at the start
7760        assert!(
7761            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7762            "forward() should drain prev_cmd from previous prefill"
7763        );
7764    }
7765
7766    #[test]
7767    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7768        let config = minimal_config();
7769        let main_rs = generate_main_rs("test-model", &config).unwrap();
7770
7771        assert!(
7772            main_rs.contains("forward_prefill"),
7773            "main.rs should use forward_prefill for intermediate prompt tokens"
7774        );
7775        assert!(
7776            main_rs.contains("double-buffered"),
7777            "main.rs should document double-buffered prefill"
7778        );
7779    }
7780
7781    #[test]
7782    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7783        let shaders = generate_metal_shaders(&minimal_config());
7784
7785        assert!(
7786            shaders.contains("packed_short4"),
7787            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7788        );
7789        assert!(
7790            shaders.contains("d0[0]"),
7791            "matmul_vec_q8 should index the wide pointer for row 0"
7792        );
7793        assert!(
7794            shaders.contains("as_type<char2>"),
7795            "matmul_vec_q8 should bitcast short lanes to char2"
7796        );
7797        assert!(
7798            shaders.contains("dot("),
7799            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7800        );
7801    }
7802
7803    // ── Q4_0 tests ──────────────────────────────────────────────────────
7804
7805    fn minimal_q4_config() -> ModelConfig {
7806        ModelConfig {
7807            architecture: Architecture::Llama,
7808            hidden_size: 64,
7809            intermediate_size: 128,
7810            num_layers: 2,
7811            num_attention_heads: 4,
7812            num_kv_heads: 4,
7813            head_dim: 16,
7814            vocab_size: 256,
7815            max_seq_len: 512,
7816            rms_norm_eps: 1e-5,
7817            rope_theta: 10000.0,
7818            dtype: DType::Q4_0,
7819            sliding_window_size: None,
7820            qkv_bias: false,
7821        }
7822    }
7823
7824    #[test]
7825    fn generated_shaders_contain_q4_kernel() {
7826        let shaders = generate_metal_shaders(&minimal_config());
7827
7828        assert!(
7829            shaders.contains("kernel void matmul_vec_q4"),
7830            "shaders should contain matmul_vec_q4 kernel"
7831        );
7832        assert!(
7833            shaders.contains("Q4_ROWS_PER_TG"),
7834            "shaders should define Q4_ROWS_PER_TG constant"
7835        );
7836        assert!(
7837            shaders.contains("Q4_ROWS_PER_SG"),
7838            "shaders should define Q4_ROWS_PER_SG constant"
7839        );
7840    }
7841
7842    #[test]
7843    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7844        let shaders = generate_metal_shaders(&minimal_config());
7845
7846        // Q4_0 kernel should use uchar4 for packed byte loads
7847        assert!(
7848            shaders.contains("uchar4"),
7849            "matmul_vec_q4 should use uchar4 for packed byte loads"
7850        );
7851        // Should unpack nibbles with &0xF and >>4
7852        assert!(
7853            shaders.contains("&0xF"),
7854            "matmul_vec_q4 should extract low nibble with &0xF"
7855        );
7856        assert!(
7857            shaders.contains(">>4"),
7858            "matmul_vec_q4 should extract high nibble with >>4"
7859        );
7860        // Should subtract 8 to convert unsigned to signed
7861        assert!(
7862            shaders.contains("-8)"),
7863            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7864        );
7865        // Should use 18-byte block size
7866        assert!(
7867            shaders.contains("blk * 18"),
7868            "matmul_vec_q4 should use 18-byte block stride"
7869        );
7870    }
7871
7872    #[test]
7873    fn q4_model_has_matmul_q4_pipeline() {
7874        let config = minimal_q4_config();
7875        let model_rs = generate_model_rs(&config).unwrap();
7876
7877        assert!(
7878            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7879            "MetalModel should have matmul_q4_pipeline field"
7880        );
7881        assert!(
7882            model_rs.contains("matmul_q4_pipeline,"),
7883            "MetalModel Self should include matmul_q4_pipeline"
7884        );
7885    }
7886
7887    #[test]
7888    fn q4_model_uses_dispatch_matmul_q4() {
7889        let config = minimal_q4_config();
7890        let model_rs = generate_model_rs(&config).unwrap();
7891
7892        assert!(
7893            model_rs.contains("dispatch_matmul_q4"),
7894            "Q4_0 model should use dispatch_matmul_q4 for projections"
7895        );
7896        assert!(
7897            model_rs.contains("fn dispatch_matmul_q4"),
7898            "model.rs should define dispatch_matmul_q4 method"
7899        );
7900    }
7901
7902    #[test]
7903    fn q4_model_loads_raw_bytes_not_dequantized() {
7904        let config = minimal_q4_config();
7905        let model_rs = generate_model_rs(&config).unwrap();
7906
7907        // Should NOT contain dequantization code
7908        assert!(
7909            !model_rs.contains("f16_to_f32"),
7910            "Q4_0 model should not dequantize weights to f32"
7911        );
7912
7913        // Should load raw Q4_0 bytes directly
7914        assert!(
7915            model_rs.contains("total_raw as u64"),
7916            "Q4_0 model should load raw bytes into Metal buffer"
7917        );
7918    }
7919
7920    #[test]
7921    fn q4_model_norms_stay_f32() {
7922        let config = minimal_q4_config();
7923        let model_rs = generate_model_rs(&config).unwrap();
7924
7925        assert!(
7926            model_rs.contains("let attn_norm = next_f32_buffer"),
7927            "attn_norm should use f32 buffer even for Q4_0 models"
7928        );
7929        assert!(
7930            model_rs.contains("let ffn_norm = next_f32_buffer"),
7931            "ffn_norm should use f32 buffer even for Q4_0 models"
7932        );
7933        assert!(
7934            model_rs.contains("let norm_buf = next_f32_buffer"),
7935            "final norm should use f32 buffer even for Q4_0 models"
7936        );
7937    }
7938
7939    #[test]
7940    fn q4_model_uses_fused_weight_loading() {
7941        let config = minimal_q4_config();
7942        let model_rs = generate_model_rs(&config).unwrap();
7943
7944        assert!(
7945            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7946            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7947        );
7948        assert!(
7949            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7950            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7951        );
7952        assert!(
7953            model_rs.contains("let o_weight = next_q4_buffer"),
7954            "Q4_0 model should use next_q4_buffer for O weight"
7955        );
7956        assert!(
7957            model_rs.contains("let down_weight = next_q4_buffer"),
7958            "Q4_0 model should use next_q4_buffer for down weight"
7959        );
7960    }
7961
7962    #[test]
7963    fn attention_flash_batch_kernel_exists() {
7964        // The flash kernel is still wired into the library (pipeline, kernel
7965        // source).  Dispatch is currently routed to the legacy path pending a
7966        // fix for a numerical issue discovered after the prompt-chunking fix.
7967        let config = minimal_config();
7968        let model_rs = generate_model_rs(&config).unwrap();
7969        let shaders = generate_metal_shaders(&config);
7970
7971        assert!(
7972            shaders.contains("kernel void attention_flash_batch"),
7973            "shaders.metal must still contain the attention_flash_batch kernel"
7974        );
7975        assert!(
7976            shaders.contains("FLASH_K_TILE"),
7977            "flash kernel must tile K/V with a FLASH_K_TILE constant"
7978        );
7979        assert!(
7980            model_rs.contains("attention_flash_batch_pipeline"),
7981            "MetalModel must register the flash attention pipeline"
7982        );
7983    }
7984
7985    #[test]
7986    fn decode_uses_fused_rope_and_kv_copy() {
7987        // v0.7.2: single-token decode reuses `rope_qk_batch` (M=1) and
7988        // `copy_kv_both_batch` (M=1) instead of calling the separate
7989        // rope_offset + copy_from_offset_f16 pairs, saving 2 dispatches +
7990        // barriers per layer.
7991        let config = minimal_config();
7992        let model_rs = generate_model_rs(&config).unwrap();
7993
7994        // forward() and forward_profile() each have a decode loop.
7995        // Locate `pub fn forward(` and `pub fn forward_profile(` and verify
7996        // the fused dispatches appear between the function header and the
7997        // start of forward_prefill (which also uses these fused helpers).
7998        let forward_start = model_rs.find("pub fn forward(").expect("forward missing");
7999        let forward_end = model_rs[forward_start..]
8000            .find("pub fn forward_prefill(")
8001            .expect("forward_prefill missing");
8002        let forward_body = &model_rs[forward_start..forward_start + forward_end];
8003
8004        assert!(
8005            forward_body.contains("dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1"),
8006            "single-token forward must use fused rope_qk_batch with M=1"
8007        );
8008        assert!(
8009            forward_body.contains("dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,"),
8010            "single-token forward must use fused copy_kv_both_batch"
8011        );
8012        assert!(
8013            !forward_body.contains("dispatch_copy_from_offset_f16"),
8014            "decode path should no longer call the per-K/per-V copy"
8015        );
8016    }
8017
8018    #[test]
8019    fn kv_cache_stored_as_f16() {
8020        // v0.7.1: KV cache is f16 (2 bytes/element) instead of f32 to halve
8021        // attention memory bandwidth and KV RAM footprint.  All attention
8022        // kernels must read K/V as `device const half*`; copy kernels must
8023        // write half; the f32->f16 copy kernel must be wired for single-
8024        // token decode KV writes.
8025        let config = minimal_config();
8026        let shaders = generate_metal_shaders(&config);
8027        let model_rs = generate_model_rs(&config).unwrap();
8028
8029        for kernel in [
8030            "kernel void attention(",
8031            "kernel void attention_batch(",
8032            "kernel void attention_flash_batch(",
8033            "kernel void attention_mma_flash_batch(",
8034        ] {
8035            let start = shaders
8036                .find(kernel)
8037                .unwrap_or_else(|| panic!("kernel {kernel} missing"));
8038            // Slice through the signature block — scan for "){" which closes
8039            // the parameter list at the top of every kernel's body.
8040            let sig_end = shaders[start..]
8041                .find("){")
8042                .unwrap_or_else(|| shaders[start..].find(") {").unwrap());
8043            let sig = &shaders[start..start + sig_end];
8044            assert!(
8045                sig.contains("device const half*")
8046                    && sig.contains("k_cache")
8047                    && sig.contains("v_cache"),
8048                "{kernel} must read k_cache/v_cache as `device const half*`"
8049            );
8050            assert!(
8051                !sig.contains("device const float* k_cache")
8052                    && !sig.contains("device const float*  k_cache"),
8053                "{kernel} still reads k_cache as float"
8054            );
8055        }
8056
8057        assert!(
8058            shaders.contains("kernel void copy_f32_to_f16_offset"),
8059            "f32->f16 copy kernel must be present for single-token decode KV writes"
8060        );
8061        assert!(
8062            model_rs.contains("dispatch_copy_from_offset_f16"),
8063            "single-token decode must dispatch the f32->f16 KV copy"
8064        );
8065    }
8066
8067    #[test]
8068    fn decode_attention_uses_half4_vectorized_loads() {
8069        // v0.7.3: vectorized K/V via half4 + float4.
8070        // v0.7.4: V-step restructured — one simdgroup per d4 chunk, 32 lanes
8071        // partition seq_len with simd_sum reduction. Fixes head_dim=64 under-
8072        // utilization (16 productive threads → 256 productive threads).
8073        let config = minimal_config();
8074        let shaders = generate_metal_shaders(&config);
8075
8076        let start = shaders
8077            .find("kernel void attention(")
8078            .expect("decode attention kernel missing");
8079        // Slice just the decode kernel body — stop at the next kernel.
8080        let end_rel = shaders[start + 1..]
8081            .find("kernel void ")
8082            .expect("next kernel missing");
8083        let body = &shaders[start..start + 1 + end_rel];
8084
8085        assert!(
8086            body.contains("device const half4*"),
8087            "decode attention must half4-load K/V"
8088        );
8089        assert!(
8090            body.contains("device const float4*"),
8091            "decode attention must float4-load Q"
8092        );
8093        assert!(
8094            body.contains("device float4*"),
8095            "decode attention must float4-store output"
8096        );
8097        assert!(
8098            body.contains("head_dim4"),
8099            "decode attention must iterate head_dim in chunks of 4"
8100        );
8101        // v0.7.4: V-step uses simd_sum per component to reduce across 32 lanes.
8102        assert!(
8103            body.contains("simd_sum(partial.x)")
8104                && body.contains("simd_sum(partial.y)")
8105                && body.contains("simd_sum(partial.z)")
8106                && body.contains("simd_sum(partial.w)"),
8107            "decode attention V-step must reduce float4 partials via simd_sum"
8108        );
8109        // And the V-step outer loop must iterate d4 across simdgroups (simd_id), not threads (tid).
8110        assert!(
8111            body.contains("for (uint d4 = simd_id; d4 < head_dim4; d4 += 8)"),
8112            "decode attention V-step must partition d4 across simdgroups"
8113        );
8114    }
8115
8116    #[test]
8117    fn attention_mma_flash_batch_kernel_wired() {
8118        // MMA-accelerated flash attention (issue #212).  Default-on in v0.7.0
8119        // when HEAD_DIM ≤ 128 and num_tokens ≥ 8.  FORGE_MMA_ATTN=0 opts out.
8120        let config = minimal_config();
8121        let model_rs = generate_model_rs(&config).unwrap();
8122        let shaders = generate_metal_shaders(&config);
8123
8124        assert!(
8125            shaders.contains("kernel void attention_mma_flash_batch"),
8126            "shaders.metal must contain the MMA flash kernel"
8127        );
8128        assert!(
8129            shaders.contains("FLASH_MMA_Q_BLOCK"),
8130            "MMA flash kernel must define Q_BLOCK tiling constant"
8131        );
8132        assert!(
8133            shaders.contains("simdgroup_multiply_accumulate"),
8134            "MMA flash kernel must use hardware MMA"
8135        );
8136        assert!(
8137            model_rs.contains("attention_mma_flash_batch_pipeline"),
8138            "MetalModel must register the MMA flash pipeline"
8139        );
8140        assert!(
8141            model_rs.contains("mma_opt_out"),
8142            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out"
8143        );
8144        assert!(
8145            model_rs.contains("!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8"),
8146            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8"
8147        );
8148    }
8149
8150    #[test]
8151    fn forward_prefill_batch_chunks_by_max_batch_size() {
8152        // Regression: prior to v0.6.4 forward_prefill_batch truncated prompts
8153        // longer than MAX_BATCH_SIZE (512) tokens via `.min(MAX_BATCH_SIZE)`,
8154        // silently dropping the middle of long prompts.  Must now loop over
8155        // MAX_BATCH_SIZE-sized chunks and carry KV-cache state across them.
8156        let config = minimal_config();
8157        let model_rs = generate_model_rs(&config).unwrap();
8158        assert!(
8159            model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
8160            "forward_prefill_batch must chunk long prompts"
8161        );
8162        assert!(
8163            !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
8164            "the old truncation path must be gone"
8165        );
8166    }
8167
8168    #[test]
8169    fn qwen2_qkv_bias_wired_through_metal_codegen() {
8170        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
8171        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
8172        // loader, pipeline, and dispatch call in the expected places.
8173        let config = ModelConfig {
8174            architecture: Architecture::Qwen2,
8175            qkv_bias: true,
8176            ..minimal_config()
8177        };
8178        let model_rs = generate_model_rs(&config).unwrap();
8179
8180        assert!(
8181            model_rs.contains("qkv_bias: Buffer"),
8182            "Qwen2 LayerBuffers must declare qkv_bias field"
8183        );
8184        assert!(
8185            model_rs.contains("let qkv_bias = next_f32_buffer"),
8186            "Qwen2 layer init must load the bias from the weight blob"
8187        );
8188        assert!(
8189            model_rs.contains("add_bias_batch_pipeline"),
8190            "Qwen2 model struct must include the add_bias_batch_pipeline"
8191        );
8192        assert!(
8193            model_rs.contains("fn dispatch_add_bias_batch"),
8194            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
8195        );
8196        assert!(
8197            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
8198            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
8199        );
8200        assert!(
8201            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
8202            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
8203        );
8204
8205        // The add_bias_batch MSL kernel must be in the shader source.
8206        let shaders = generate_metal_shaders(&config);
8207        assert!(
8208            shaders.contains("kernel void add_bias_batch"),
8209            "shaders.metal must contain the add_bias_batch kernel"
8210        );
8211    }
8212
8213    #[test]
8214    fn llama_does_not_emit_qkv_bias_machinery() {
8215        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
8216        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
8217        let config = minimal_config();
8218        assert!(!config.qkv_bias);
8219        let model_rs = generate_model_rs(&config).unwrap();
8220        assert!(
8221            !model_rs.contains("qkv_bias: Buffer"),
8222            "Llama must not have qkv_bias field"
8223        );
8224        assert!(
8225            !model_rs.contains("add_bias_batch_pipeline"),
8226            "Llama must not pull in add_bias_batch_pipeline"
8227        );
8228        assert!(
8229            !model_rs.contains("dispatch_add_bias_batch"),
8230            "Llama must not call dispatch_add_bias_batch"
8231        );
8232    }
8233
8234    #[test]
8235    fn q4_dispatch_helper_takes_compute_encoder_ref() {
8236        let config = minimal_q4_config();
8237        let model_rs = generate_model_rs(&config).unwrap();
8238
8239        assert!(
8240            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
8241            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
8242        );
8243    }
8244
8245    #[test]
8246    fn f32_model_does_not_use_q4_dispatch() {
8247        let config = minimal_config();
8248        let model_rs = generate_model_rs(&config).unwrap();
8249
8250        let forward_start = model_rs
8251            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8252            .unwrap();
8253        let forward_body = &model_rs[forward_start..];
8254        let forward_end = forward_body
8255            .find("\n    fn dispatch_")
8256            .unwrap_or(forward_body.len());
8257        let forward_code = &forward_body[..forward_end];
8258
8259        assert!(
8260            !forward_code.contains("dispatch_matmul_q4"),
8261            "f32 model forward should not use dispatch_matmul_q4"
8262        );
8263    }
8264
8265    #[test]
8266    fn q4_model_lm_head_uses_q4_buffer() {
8267        let config = minimal_q4_config();
8268        let model_rs = generate_model_rs(&config).unwrap();
8269
8270        assert!(
8271            model_rs.contains("let lm_head_buf = next_q4_buffer"),
8272            "Q4_0 model should use next_q4_buffer for lm_head"
8273        );
8274    }
8275
8276    #[test]
8277    fn vec_tile_size_matches_model_dimensions() {
8278        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
8279        let small = minimal_config();
8280        let shaders_small = generate_metal_shaders(&small);
8281        assert!(
8282            shaders_small.contains("vec_tile[128]"),
8283            "vec_tile should be sized to max(hidden, intermediate) = 128"
8284        );
8285
8286        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
8287        let mut large = minimal_config();
8288        large.hidden_size = 2048;
8289        large.intermediate_size = 8192;
8290        let shaders_large = generate_metal_shaders(&large);
8291        assert!(
8292            shaders_large.contains("vec_tile[8192]"),
8293            "vec_tile should be 8192 for models with intermediate=8192"
8294        );
8295        assert!(
8296            !shaders_large.contains("vec_tile[4096]"),
8297            "vec_tile should NOT be hardcoded to 4096"
8298        );
8299    }
8300
8301    #[test]
8302    fn generated_cargo_toml_has_server_deps() {
8303        let toml = generate_cargo_toml("my-model");
8304        assert!(
8305            toml.contains("tiny_http"),
8306            "Cargo.toml should depend on tiny_http"
8307        );
8308        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
8309        assert!(
8310            toml.contains("serde_json"),
8311            "Cargo.toml should depend on serde_json"
8312        );
8313    }
8314
8315    #[test]
8316    fn generated_main_rs_has_serve_mode() {
8317        let config = minimal_config();
8318        let main_rs = generate_main_rs("test-model", &config).unwrap();
8319
8320        assert!(
8321            main_rs.contains("--serve"),
8322            "main.rs should parse --serve flag"
8323        );
8324        assert!(
8325            main_rs.contains("--port"),
8326            "main.rs should parse --port flag"
8327        );
8328        assert!(
8329            main_rs.contains("fn serve("),
8330            "main.rs should define serve function"
8331        );
8332        assert!(
8333            main_rs.contains("tiny_http::Server::http"),
8334            "main.rs should create tiny_http server"
8335        );
8336    }
8337
8338    #[test]
8339    fn generated_main_rs_has_chat_completions_endpoint() {
8340        let config = minimal_config();
8341        let main_rs = generate_main_rs("test-model", &config).unwrap();
8342
8343        assert!(
8344            main_rs.contains("/v1/chat/completions"),
8345            "main.rs should handle /v1/chat/completions endpoint"
8346        );
8347        assert!(
8348            main_rs.contains("/v1/models"),
8349            "main.rs should handle /v1/models endpoint"
8350        );
8351        assert!(
8352            main_rs.contains("/health"),
8353            "main.rs should handle /health endpoint"
8354        );
8355    }
8356
8357    #[test]
8358    fn generated_main_rs_has_sse_streaming() {
8359        let config = minimal_config();
8360        let main_rs = generate_main_rs("test-model", &config).unwrap();
8361
8362        assert!(
8363            main_rs.contains("text/event-stream"),
8364            "main.rs should set SSE content type for streaming"
8365        );
8366        assert!(
8367            main_rs.contains("chat.completion.chunk"),
8368            "main.rs should emit SSE chunks"
8369        );
8370        assert!(
8371            main_rs.contains("[DONE]"),
8372            "main.rs should emit [DONE] sentinel"
8373        );
8374    }
8375
8376    #[test]
8377    fn generated_main_rs_has_chat_message_formatting() {
8378        let config = minimal_config();
8379        let main_rs = generate_main_rs("test-model", &config).unwrap();
8380
8381        assert!(
8382            main_rs.contains("fn format_chat_messages"),
8383            "main.rs should define format_chat_messages function"
8384        );
8385        assert!(
8386            main_rs.contains("<|im_start|>"),
8387            "main.rs should use ChatML format"
8388        );
8389        assert!(
8390            main_rs.contains("<|im_end|>"),
8391            "main.rs should use ChatML format"
8392        );
8393    }
8394
8395    #[test]
8396    fn generated_main_rs_has_request_types() {
8397        let config = minimal_config();
8398        let main_rs = generate_main_rs("test-model", &config).unwrap();
8399
8400        assert!(
8401            main_rs.contains("struct ChatRequest"),
8402            "main.rs should define ChatRequest struct"
8403        );
8404        assert!(
8405            main_rs.contains("struct ChatMessage"),
8406            "main.rs should define ChatMessage struct"
8407        );
8408        assert!(
8409            main_rs.contains("Deserialize"),
8410            "main.rs should derive Deserialize for request types"
8411        );
8412    }
8413
8414    #[test]
8415    fn generated_model_has_reset_method() {
8416        let config = minimal_config();
8417        let model_rs = generate_model_rs(&config).unwrap();
8418
8419        assert!(
8420            model_rs.contains("pub fn reset(&mut self)"),
8421            "model.rs should have a reset() method for multi-request serving"
8422        );
8423        assert!(
8424            model_rs.contains("self.pos = 0"),
8425            "reset() should reset position to 0"
8426        );
8427    }
8428
8429    #[test]
8430    fn generated_main_rs_cli_mode_still_works() {
8431        let config = minimal_config();
8432        let main_rs = generate_main_rs("test-model", &config).unwrap();
8433
8434        // CLI mode should still be functional
8435        assert!(
8436            main_rs.contains("fn cli_mode("),
8437            "main.rs should define cli_mode function"
8438        );
8439        assert!(
8440            main_rs.contains("model.forward"),
8441            "main.rs should call model.forward"
8442        );
8443        assert!(
8444            main_rs.contains("model.forward_prefill"),
8445            "main.rs should call model.forward_prefill"
8446        );
8447    }
8448
8449    // ── Batched prefill tests ──────────────────────────────────────────
8450
8451    #[test]
8452    fn generated_shaders_contain_batch_kernels() {
8453        let shaders = generate_metal_shaders(&minimal_config());
8454
8455        assert!(
8456            shaders.contains("kernel void matmul_vec_batch"),
8457            "shaders should contain matmul_vec_batch kernel"
8458        );
8459        assert!(
8460            shaders.contains("kernel void matmul_vec_q8_batch"),
8461            "shaders should contain matmul_vec_q8_batch kernel"
8462        );
8463        assert!(
8464            shaders.contains("kernel void matmul_q8_gemm_batch"),
8465            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
8466        );
8467        assert!(
8468            shaders.contains("kernel void matmul_vec_q4_batch"),
8469            "shaders should contain matmul_vec_q4_batch kernel"
8470        );
8471        assert!(
8472            shaders.contains("kernel void rms_norm_batch"),
8473            "shaders should contain rms_norm_batch kernel"
8474        );
8475        assert!(
8476            shaders.contains("kernel void silu_mul_fused_batch"),
8477            "shaders should contain silu_mul_fused_batch kernel"
8478        );
8479        assert!(
8480            shaders.contains("kernel void add_inplace_batch"),
8481            "shaders should contain add_inplace_batch kernel"
8482        );
8483        assert!(
8484            shaders.contains("kernel void copy_embedding_batch"),
8485            "shaders should contain copy_embedding_batch kernel"
8486        );
8487    }
8488
8489    #[test]
8490    fn generated_model_has_batch_pipelines() {
8491        let config = minimal_config();
8492        let model_rs = generate_model_rs(&config).unwrap();
8493
8494        for pipeline in &[
8495            "matmul_batch_pipeline",
8496            "matmul_q8_batch_pipeline",
8497            "matmul_q8_gemm_batch_pipeline",
8498            "matmul_q4_batch_pipeline",
8499            "rms_norm_batch_pipeline",
8500            "rope_batch_pipeline",
8501            "silu_mul_fused_batch_pipeline",
8502            "add_inplace_batch_pipeline",
8503            "copy_embedding_batch_pipeline",
8504            "attention_batch_pipeline",
8505            "copy_kv_batch_pipeline",
8506        ] {
8507            assert!(
8508                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
8509                "MetalModel should have {pipeline} field"
8510            );
8511        }
8512    }
8513
8514    #[test]
8515    fn generated_model_has_batch_buffers() {
8516        let config = minimal_config();
8517        let model_rs = generate_model_rs(&config).unwrap();
8518
8519        for buf in &[
8520            "batch_hidden_buf",
8521            "batch_residual_buf",
8522            "batch_qkv_buf",
8523            "batch_attn_out_buf",
8524            "batch_attn_proj_buf",
8525            "batch_gate_up_buf",
8526            "batch_ffn_hidden_buf",
8527            "batch_ffn_out_buf",
8528            "batch_tokens_buf",
8529            "batch_positions_buf",
8530        ] {
8531            assert!(
8532                model_rs.contains(&format!("{buf}: Buffer")),
8533                "MetalModel should have {buf} field"
8534            );
8535        }
8536    }
8537
8538    #[test]
8539    fn generated_model_has_forward_prefill_batch() {
8540        let config = minimal_config();
8541        let model_rs = generate_model_rs(&config).unwrap();
8542
8543        assert!(
8544            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8545            "MetalModel should have forward_prefill_batch method"
8546        );
8547
8548        // forward_prefill should delegate to forward_prefill_batch
8549        assert!(
8550            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8551            "forward_prefill should delegate to forward_prefill_batch"
8552        );
8553    }
8554
8555    #[test]
8556    fn generated_model_has_max_batch_size_constant() {
8557        let config = minimal_config();
8558        let model_rs = generate_model_rs(&config).unwrap();
8559
8560        assert!(
8561            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8562            "model.rs should define MAX_BATCH_SIZE constant"
8563        );
8564    }
8565
8566    #[test]
8567    fn forward_prefill_batch_uses_batch_dispatch() {
8568        let config = minimal_config();
8569        let model_rs = generate_model_rs(&config).unwrap();
8570
8571        let batch_start = model_rs
8572            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8573            .unwrap();
8574        let batch_body = &model_rs[batch_start..];
8575        let batch_end = batch_body
8576            .find("\n    pub fn reset")
8577            .unwrap_or(batch_body.len());
8578        let batch_code = &batch_body[..batch_end];
8579
8580        // Should use batched dispatch methods
8581        assert!(
8582            batch_code.contains("dispatch_rms_norm_batch"),
8583            "forward_prefill_batch should use dispatch_rms_norm_batch"
8584        );
8585        assert!(
8586            batch_code.contains("dispatch_copy_embedding_batch"),
8587            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8588        );
8589        assert!(
8590            batch_code.contains("dispatch_silu_mul_fused_batch"),
8591            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8592        );
8593        // Should use batched causal attention dispatch
8594        assert!(
8595            batch_code.contains("dispatch_attention_batch"),
8596            "forward_prefill_batch should use dispatch_attention_batch"
8597        );
8598        // Should use fused KV cache copy (both K and V in one dispatch)
8599        assert!(
8600            batch_code.contains("dispatch_copy_kv_both_batch"),
8601            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8602        );
8603        // Should use fused RoPE Q+K dispatch
8604        assert!(
8605            batch_code.contains("dispatch_rope_qk_batch"),
8606            "forward_prefill_batch should use dispatch_rope_qk_batch"
8607        );
8608    }
8609
8610    #[test]
8611    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8612        let config = minimal_q8_config();
8613        let model_rs = generate_model_rs(&config).unwrap();
8614
8615        let batch_start = model_rs
8616            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8617            .unwrap();
8618        let batch_body = &model_rs[batch_start..];
8619        let batch_end = batch_body
8620            .find("\n    pub fn reset")
8621            .unwrap_or(batch_body.len());
8622        let batch_code = &batch_body[..batch_end];
8623
8624        assert!(
8625            batch_code.contains("dispatch_matmul_q8_batch"),
8626            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8627        );
8628    }
8629
8630    #[test]
8631    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8632        let config = minimal_q4_config();
8633        let model_rs = generate_model_rs(&config).unwrap();
8634
8635        let batch_start = model_rs
8636            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8637            .unwrap();
8638        let batch_body = &model_rs[batch_start..];
8639        let batch_end = batch_body
8640            .find("\n    pub fn reset")
8641            .unwrap_or(batch_body.len());
8642        let batch_code = &batch_body[..batch_end];
8643
8644        assert!(
8645            batch_code.contains("dispatch_matmul_q4_batch"),
8646            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8647        );
8648    }
8649
8650    #[test]
8651    fn generated_main_rs_uses_batched_prefill() {
8652        let config = minimal_config();
8653        let main_rs = generate_main_rs("test-model", &config).unwrap();
8654
8655        assert!(
8656            main_rs.contains("forward_prefill_batch"),
8657            "main.rs should use forward_prefill_batch for prompt tokens"
8658        );
8659    }
8660
8661    #[test]
8662    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8663        let config = minimal_config();
8664        let model_rs = generate_model_rs(&config).unwrap();
8665
8666        let batch_start = model_rs
8667            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8668            .unwrap();
8669        let batch_body = &model_rs[batch_start..];
8670        let batch_end = batch_body
8671            .find("\n    pub fn reset")
8672            .unwrap_or(batch_body.len());
8673        let batch_code = &batch_body[..batch_end];
8674
8675        assert!(
8676            batch_code.contains("dispatch_matmul_batch"),
8677            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8678        );
8679        // Should NOT use Q8 or Q4 batch dispatch
8680        assert!(
8681            !batch_code.contains("dispatch_matmul_q8_batch"),
8682            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8683        );
8684        assert!(
8685            !batch_code.contains("dispatch_matmul_q4_batch"),
8686            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8687        );
8688    }
8689
8690    #[test]
8691    fn forward_uses_cpu_embedding_lookup() {
8692        let config = minimal_config();
8693        let model_rs = generate_model_rs(&config).unwrap();
8694
8695        // Find just the forward() body (not forward_profile)
8696        let forward_start = model_rs
8697            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8698            .unwrap();
8699        let forward_body = &model_rs[forward_start..];
8700        let forward_end = forward_body
8701            .find("\n    pub fn forward_profile")
8702            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8703            .unwrap_or(forward_body.len());
8704        let forward_code = &forward_body[..forward_end];
8705
8706        // forward() should use CPU memcpy for embedding lookup (unified memory)
8707        assert!(
8708            forward_code.contains("embed_buf.contents()"),
8709            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8710        );
8711        assert!(
8712            forward_code.contains("copy_nonoverlapping"),
8713            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8714        );
8715        // forward() should NOT use GPU dispatch for embedding
8716        assert!(
8717            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8718            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8719        );
8720    }
8721
8722    #[test]
8723    fn forward_profile_method_exists() {
8724        let config = minimal_config();
8725        let model_rs = generate_model_rs(&config).unwrap();
8726
8727        assert!(
8728            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8729            "MetalModel should have forward_profile() method"
8730        );
8731        // Profile method should print timing information
8732        assert!(
8733            model_rs.contains("[profile]"),
8734            "forward_profile() should print timing with [profile] prefix"
8735        );
8736        assert!(
8737            model_rs.contains("d_embed"),
8738            "forward_profile() should measure embedding time"
8739        );
8740        assert!(
8741            model_rs.contains("d_layers"),
8742            "forward_profile() should measure layer time"
8743        );
8744        assert!(
8745            model_rs.contains("d_logits"),
8746            "forward_profile() should measure logits time"
8747        );
8748    }
8749
8750    #[test]
8751    fn generated_cli_has_profile_flag() {
8752        let config = minimal_config();
8753        let main_rs = generate_main_rs("test-model", &config).unwrap();
8754
8755        assert!(
8756            main_rs.contains("--profile"),
8757            "CLI should support --profile flag"
8758        );
8759        assert!(
8760            main_rs.contains("forward_profile"),
8761            "CLI should call forward_profile when --profile is set"
8762        );
8763    }
8764
8765    #[test]
8766    fn generated_cli_has_thermal_yield() {
8767        let config = minimal_config();
8768        let main_rs = generate_main_rs("test-model", &config).unwrap();
8769
8770        assert!(
8771            main_rs.contains("yield_now()"),
8772            "CLI generation loop should include thread::yield_now() for thermal management"
8773        );
8774    }
8775
8776    // ── Real-world validation tests ──────────────────────────────────────
8777
8778    #[test]
8779    fn generated_forward_handles_single_token_prompt() {
8780        // With a single token (the first prompt token), forward() should work
8781        // at pos=0 where seq_len=1. The attention kernel must handle the case
8782        // where there is only one KV entry (no prefill context).
8783        let config = minimal_config();
8784        let model_rs = generate_model_rs(&config).unwrap();
8785
8786        // The forward function should accept any u32 token_id (no minimum pos guard)
8787        let forward_start = model_rs
8788            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8789            .expect("forward() must exist");
8790        let forward_body = &model_rs[forward_start..forward_start + 400];
8791
8792        // Should NOT require pos > 0 or seq_len > 1
8793        assert!(
8794            !forward_body.contains("assert!(self.pos > 0"),
8795            "forward() must accept pos=0 (first token with no prefill)"
8796        );
8797
8798        // The attention kernel should handle seq_len=1 via the pos field
8799        assert!(
8800            model_rs.contains("self.pos"),
8801            "forward() should use self.pos to track sequence position"
8802        );
8803    }
8804
8805    #[test]
8806    fn generated_reset_clears_kv_cache_position() {
8807        // After reset(), the model should be in a clean state. The pos field
8808        // must be 0 so new generation starts from scratch.
8809        let config = minimal_config();
8810        let model_rs = generate_model_rs(&config).unwrap();
8811
8812        let reset_start = model_rs
8813            .find("pub fn reset(&mut self)")
8814            .expect("reset() must exist");
8815        let reset_body = &model_rs[reset_start..reset_start + 200];
8816
8817        // Reset must zero the position counter
8818        assert!(
8819            reset_body.contains("self.pos = 0"),
8820            "reset() must set self.pos = 0"
8821        );
8822
8823        // Verify reset clears prev_cmd (double-buffering state)
8824        assert!(
8825            reset_body.contains("self.prev_cmd = None"),
8826            "reset() should clear prev_cmd for clean command buffer state"
8827        );
8828    }
8829
8830    #[test]
8831    fn generated_serve_handles_empty_messages_gracefully() {
8832        // The serve endpoint should not crash when receiving an empty messages array.
8833        // The format_chat_messages function should handle this gracefully.
8834        let config = minimal_config();
8835        let main_rs = generate_main_rs("test-model", &config).unwrap();
8836
8837        // The format_chat_messages function should exist and handle empty input
8838        let format_fn_start = main_rs
8839            .find("fn format_chat_messages")
8840            .expect("format_chat_messages must exist");
8841        let format_fn_body =
8842            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8843
8844        // It should iterate over messages (an empty slice produces an empty loop)
8845        assert!(
8846            format_fn_body.contains("for msg in messages"),
8847            "format_chat_messages should iterate over the messages slice"
8848        );
8849        // It should always append the assistant prompt suffix
8850        assert!(
8851            format_fn_body.contains("<|im_start|>assistant"),
8852            "format_chat_messages should always append assistant prompt header"
8853        );
8854
8855        // The serve function should call model.reset() before each request
8856        let serve_fn_start = main_rs
8857            .find("fn serve(")
8858            .expect("serve function must exist");
8859        let serve_fn_body = &main_rs[serve_fn_start..];
8860        assert!(
8861            serve_fn_body.contains("model.reset()"),
8862            "serve function should reset model between requests"
8863        );
8864    }
8865
8866    #[test]
8867    fn generated_model_forward_increments_pos() {
8868        // Each forward() call must increment self.pos so the next token
8869        // uses the correct RoPE position and KV cache offset.
8870        let config = minimal_config();
8871        let model_rs = generate_model_rs(&config).unwrap();
8872
8873        let forward_start = model_rs
8874            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8875            .unwrap();
8876        let forward_body = &model_rs[forward_start..];
8877        let forward_end = forward_body
8878            .find("\n    pub fn forward_profile")
8879            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8880            .or_else(|| forward_body.find("\n    fn dispatch_"))
8881            .unwrap_or(forward_body.len());
8882        let forward_code = &forward_body[..forward_end];
8883
8884        assert!(
8885            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8886            "forward() must increment self.pos after processing a token"
8887        );
8888    }
8889}