Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a string usable as a Cargo package and
/// binary name.
///
/// The name is lowercased, every character that is neither alphanumeric nor
/// `-` is replaced by `-`, and leading/trailing `-` are stripped.
///
/// Cargo rejects package names that are empty or that do not start with a
/// letter, so two extra rules keep the generated manifest valid:
///
/// - if nothing survives sanitization, the fallback name `model` is returned;
/// - if the result starts with a non-letter (e.g. a digit, as in `7b-chat`),
///   it is prefixed with `model-`.
///
/// Names that were already valid are returned exactly as before.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    let cleaned: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = cleaned.trim_matches('-');

    match trimmed.chars().next() {
        // Nothing usable left (empty input or only separators/punctuation).
        None => FALLBACK.to_string(),
        // A leading digit or other non-letter is not accepted by Cargo.
        Some(first) if !first.is_alphabetic() => format!("{}-{}", FALLBACK, trimmed),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel. For standard LLM architectures this is
118    // max(hidden_size, intermediate_size). Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    // For models with intermediate_size > 8192 (e.g. Gemma-2B with 16384),
121    // the per-token matmul_q8/matmul_q8_batch dispatch helpers route the
122    // down-proj call through matmul_q8_gemm_batch instead, which reads
123    // inputs directly from device memory without caching.
124    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
125    // The attention-scores shared array must be at least effective_seq_len
126    // elements; anything smaller silently overflows for long prompts.  For
127    // small-context models (135M at 2K), a smaller array saves threadgroup
128    // memory and improves occupancy — so we size it precisely.
129    let attn_scores_size = config.max_seq_len.min(4096);
130    r#"//
131// Auto-generated by ForgeLLM Metal codegen.
132// Metal Shading Language compute kernels for transformer inference.
133//
134// Optimized with simdgroup cooperative reductions, shared memory vector
135// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
136// lane, and fast:: math intrinsics for Apple Silicon throughput.
137//
138
139#include <metal_stdlib>
140#include <metal_simdgroup_matrix>
141using namespace metal;
142
143// ── Constants ───────────────────────────────────────────────────────────
144// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
145// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
146// per shared memory load vs 4-row, improving ILP and reducing launches.
147constant constexpr uint SIMDGROUPS_PER_TG = 8;
148constant constexpr uint ROWS_PER_SIMDGROUP = 8;
149constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
150
151// ── matmul_vec ──────────────────────────────────────────────────────────
152// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
153// Uses simdgroup cooperative dot product with shared memory vector caching
154// and float4 vectorized loads. Each simdgroup processes 8 rows for better
155// shared memory reuse (8x vector reuse per load) and instruction-level
156// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
157kernel void matmul_vec(
158    device const float* matrix [[buffer(0)]],
159    device const float* vector [[buffer(1)]],
160    device float* output       [[buffer(2)]],
161    constant uint& rows        [[buffer(3)]],
162    constant uint& cols        [[buffer(4)]],
163    uint tgid [[threadgroup_position_in_grid]],
164    uint tid [[thread_index_in_threadgroup]],
165    uint simd_lane [[thread_index_in_simdgroup]],
166    uint simd_id [[simdgroup_index_in_threadgroup]])
167{
168    // Cooperatively load vector into threadgroup shared memory
169    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
170    for (uint i = tid; i < cols; i += 256) {
171        vec_tile[i] = vector[i];
172    }
173    threadgroup_barrier(mem_flags::mem_threadgroup);
174
175    // Each simdgroup handles 8 consecutive rows
176    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
177    if (row_base >= rows) return;
178
179    uint base0 = row_base * cols;
180    uint base1 = (row_base + 1) * cols;
181    uint base2 = (row_base + 2) * cols;
182    uint base3 = (row_base + 3) * cols;
183    uint base4 = (row_base + 4) * cols;
184    uint base5 = (row_base + 5) * cols;
185    uint base6 = (row_base + 6) * cols;
186    uint base7 = (row_base + 7) * cols;
187
188    // float4 vectorized accumulation across 8 rows
189    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
190    float4 sum4_0 = float4(0.0f);
191    float4 sum4_1 = float4(0.0f);
192    float4 sum4_2 = float4(0.0f);
193    float4 sum4_3 = float4(0.0f);
194    float4 sum4_4 = float4(0.0f);
195    float4 sum4_5 = float4(0.0f);
196    float4 sum4_6 = float4(0.0f);
197    float4 sum4_7 = float4(0.0f);
198
199    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
200        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
201        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
202        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
203        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
204        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
205        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
206        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
207        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
208        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
209    }
210
211    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
212    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
213    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
214    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
215    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
216    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
217    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
218    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
219
220    // Handle remaining elements (cols not divisible by 128)
221    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
222        float vv = vec_tile[j];
223        sum0 += matrix[base0 + j] * vv;
224        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
225        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
226        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
227        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
228        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
229        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
230        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
231    }
232
233    // Simdgroup hardware warp-level reduction
234    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
235    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
236    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
237    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
238
239    // Only first lane writes the results
240    if (simd_lane == 0) {
241        if (row_base     < rows) output[row_base]     = sum0;
242        if (row_base + 1 < rows) output[row_base + 1] = sum1;
243        if (row_base + 2 < rows) output[row_base + 2] = sum2;
244        if (row_base + 3 < rows) output[row_base + 3] = sum3;
245        if (row_base + 4 < rows) output[row_base + 4] = sum4;
246        if (row_base + 5 < rows) output[row_base + 5] = sum5;
247        if (row_base + 6 < rows) output[row_base + 6] = sum6;
248        if (row_base + 7 < rows) output[row_base + 7] = sum7;
249    }
250}
251
252// ── rms_norm ────────────────────────────────────────────────────────────
253// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
254// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
255// via shared memory for minimal synchronization overhead.
256kernel void rms_norm(
257    device const float* input   [[buffer(0)]],
258    device const float* weight  [[buffer(1)]],
259    device float* output        [[buffer(2)]],
260    constant uint& n            [[buffer(3)]],
261    constant float& eps         [[buffer(4)]],
262    uint tid [[thread_index_in_threadgroup]])
263{
264    // Each thread accumulates partial sum-of-squares
265    float sum_sq = 0.0f;
266    for (uint i = tid; i < n; i += 256) {
267        float v = input[i];
268        sum_sq += v * v;
269    }
270
271    // Simdgroup-level reduction (hardware warp sum)
272    sum_sq = simd_sum(sum_sq);
273
274    // Cross-simdgroup reduction via shared memory
275    threadgroup float shared[8];
276    uint simd_id = tid / 32;
277    uint simd_lane = tid % 32;
278    if (simd_lane == 0) {
279        shared[simd_id] = sum_sq;
280    }
281    threadgroup_barrier(mem_flags::mem_threadgroup);
282
283    // First thread computes final inverse RMS
284    if (tid == 0) {
285        float total = 0.0f;
286        for (uint i = 0; i < 8; i++) {
287            total += shared[i];
288        }
289        shared[0] = fast::rsqrt(total / float(n) + eps);
290    }
291    threadgroup_barrier(mem_flags::mem_threadgroup);
292
293    float inv_rms = shared[0];
294
295    // Normalize
296    for (uint i = tid; i < n; i += 256) {
297        output[i] = input[i] * inv_rms * weight[i];
298    }
299}
300
301// ── rope ────────────────────────────────────────────────────────────────
302// Rotary Position Embedding applied in-place.
303// Each thread handles one (head, pair) combination.
304kernel void rope(
305    device float* data        [[buffer(0)]],
306    constant uint& num_heads  [[buffer(1)]],
307    constant uint& head_dim   [[buffer(2)]],
308    constant uint& pos        [[buffer(3)]],
309    constant float& theta     [[buffer(4)]],
310    uint id [[thread_position_in_grid]])
311{
312    uint half_dim = head_dim / 2;
313    uint total_pairs = num_heads * half_dim;
314    if (id >= total_pairs) return;
315
316    uint h = id / half_dim;
317    uint i = id % half_dim;
318    uint off = h * head_dim;
319
320    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
321    float angle = float(pos) * freq;
322    float c = cos(angle);
323    float s = sin(angle);
324
325    float x0 = data[off + 2 * i];
326    float x1 = data[off + 2 * i + 1];
327    data[off + 2 * i]     = x0 * c - x1 * s;
328    data[off + 2 * i + 1] = x0 * s + x1 * c;
329}
330
331// ── softmax ─────────────────────────────────────────────────────────────
332// Numerically stable softmax over a 1-D array.
333// Single-threadgroup kernel with cooperative reduction.
334kernel void softmax(
335    device float* data       [[buffer(0)]],
336    constant uint& n         [[buffer(1)]],
337    uint tid [[thread_index_in_threadgroup]],
338    uint tg_size [[threads_per_threadgroup]])
339{
340    threadgroup float shared_val[256];
341
342    // Pass 1: find max
343    float local_max = -INFINITY;
344    for (uint i = tid; i < n; i += tg_size) {
345        local_max = max(local_max, data[i]);
346    }
347    shared_val[tid] = local_max;
348    threadgroup_barrier(mem_flags::mem_threadgroup);
349
350    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
351        if (tid < stride) {
352            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
353        }
354        threadgroup_barrier(mem_flags::mem_threadgroup);
355    }
356    float max_val = shared_val[0];
357    threadgroup_barrier(mem_flags::mem_threadgroup);
358
359    // Pass 2: exp and sum
360    float local_sum = 0.0f;
361    for (uint i = tid; i < n; i += tg_size) {
362        float e = fast::exp(data[i] - max_val);
363        data[i] = e;
364        local_sum += e;
365    }
366    shared_val[tid] = local_sum;
367    threadgroup_barrier(mem_flags::mem_threadgroup);
368
369    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
370        if (tid < stride) {
371            shared_val[tid] += shared_val[tid + stride];
372        }
373        threadgroup_barrier(mem_flags::mem_threadgroup);
374    }
375    float inv_sum = 1.0f / shared_val[0];
376    threadgroup_barrier(mem_flags::mem_threadgroup);
377
378    // Pass 3: normalize
379    for (uint i = tid; i < n; i += tg_size) {
380        data[i] *= inv_sum;
381    }
382}
383
384// ── silu_mul ────────────────────────────────────────────────────────────
385// Fused SiLU activation * element-wise multiply:
386//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
387kernel void silu_mul(
388    device const float* gate [[buffer(0)]],
389    device const float* up   [[buffer(1)]],
390    device float* output     [[buffer(2)]],
391    constant uint& n         [[buffer(3)]],
392    uint id [[thread_position_in_grid]])
393{
394    if (id >= n) return;
395    float g = gate[id];
396    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
397}
398
399// ── silu_mul_fused ─────────────────────────────────────────────────────
400// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
401//   gate = gate_up[0..n], up = gate_up[n..2*n]
402//   output[i] = silu(gate_up[i]) * gate_up[n + i]
403kernel void silu_mul_fused(
404    device const float* gate_up [[buffer(0)]],
405    device float* output        [[buffer(1)]],
406    constant uint& n            [[buffer(2)]],
407    uint id [[thread_position_in_grid]])
408{
409    if (id >= n) return;
410    float g = gate_up[id];
411    float u = gate_up[n + id];
412    output[id] = (g / (1.0f + fast::exp(-g))) * u;
413}
414
415// ── gelu_mul_fused ─────────────────────────────────────────────────────
416// Fused GELU-tanh-approx × multiply. Same layout as silu_mul_fused but
417// uses the tanh-approximate GELU (matches HF's `gelu_pytorch_tanh`, which
418// is what Gemma-1 uses):
419//   GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
420//   output[i] = GELU(gate_up[i]) * gate_up[n + i]
421kernel void gelu_mul_fused(
422    device const float* gate_up [[buffer(0)]],
423    device float* output        [[buffer(1)]],
424    constant uint& n            [[buffer(2)]],
425    uint id [[thread_position_in_grid]])
426{
427    if (id >= n) return;
428    constexpr float SQRT_2_OVER_PI = 0.7978845608f;
429    float g = gate_up[id];
430    float u = gate_up[n + id];
431    float inner = SQRT_2_OVER_PI * (g + 0.044715f * g * g * g);
432    float gelu = 0.5f * g * (1.0f + precise::tanh(inner));
433    output[id] = gelu * u;
434}
435
436// ── elementwise_add ─────────────────────────────────────────────────────
437// Residual connection: output[i] = a[i] + b[i]
438kernel void elementwise_add(
439    device const float* a  [[buffer(0)]],
440    device const float* b  [[buffer(1)]],
441    device float* output   [[buffer(2)]],
442    constant uint& n       [[buffer(3)]],
443    uint id [[thread_position_in_grid]])
444{
445    if (id >= n) return;
446    output[id] = a[id] + b[id];
447}
448
449// ── copy_buffer ─────────────────────────────────────────────────────────
450// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
451// transitions. Used for KV cache updates and embedding lookup.
452kernel void copy_buffer(
453    device const float* src [[buffer(0)]],
454    device float* dst       [[buffer(1)]],
455    constant uint& count    [[buffer(2)]],
456    uint id [[thread_position_in_grid]])
457{
458    if (id < count) dst[id] = src[id];
459}
460
461// ── copy_offset ─────────────────────────────────────────────────────────
462// Copy with source offset (in floats). Used for embedding table lookup
463// where we need to copy a specific row from a large table.
464kernel void copy_offset(
465    device const float* src     [[buffer(0)]],
466    device float* dst           [[buffer(1)]],
467    constant uint& src_offset   [[buffer(2)]],  // in floats
468    constant uint& count        [[buffer(3)]],
469    uint id [[thread_position_in_grid]])
470{
471    if (id < count) dst[id] = src[src_offset + id];
472}
473
474// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
475// Copy f32 elements from src into a half-typed dst, converting to half on
476// write.  Used by the single-token decode path to append a new K/V vector
477// to the f16 KV cache.  Byte offsets into src/dst are supplied via the
478// Metal buffer binding offsets — no in-kernel offsets needed.
479kernel void copy_f32_to_f16_offset(
480    device const float* src     [[buffer(0)]],
481    device half* dst            [[buffer(1)]],
482    constant uint& count        [[buffer(2)]],
483    uint id [[thread_position_in_grid]])
484{
485    if (id < count) dst[id] = half(src[id]);
486}
487
488// ── add_inplace ─────────────────────────────────────────────────────────
489// In-place residual connection: a[i] += b[i]
490// Avoids a separate blit copy for residual add, reducing encoder overhead.
491kernel void add_inplace(
492    device float* a        [[buffer(0)]],
493    device const float* b  [[buffer(1)]],
494    constant uint& n       [[buffer(2)]],
495    uint id [[thread_position_in_grid]])
496{
497    if (id >= n) return;
498    a[id] += b[id];
499}
500
501// ── matmul_vec_q8 ─────────────────────────────────────────────────────
502// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
503// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
504// Operates directly on quantized weights to halve memory bandwidth vs f32,
505// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
506//
507// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
508// because int8->float conversion doubles register demand.  Fully unrolled
509// inner loop with float4 vector loads from shared memory eliminates loop
510// overhead and enables better instruction scheduling.
511// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
512constant constexpr uint Q8_ROWS_PER_SG = 4;
513constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
514
515// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
516// doubles ALU work, so the same register budget applies).
517constant constexpr uint Q4_ROWS_PER_SG = 4;
518constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
519
520kernel void matmul_vec_q8(
521    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
522    device const float* vector   [[buffer(1)]],  // f32 input
523    device float* output         [[buffer(2)]],
524    constant uint& rows          [[buffer(3)]],
525    constant uint& cols          [[buffer(4)]],  // number of elements per row
526    uint tgid [[threadgroup_position_in_grid]],
527    uint tid [[thread_index_in_threadgroup]],
528    uint simd_lane [[thread_index_in_simdgroup]],
529    uint simd_id [[simdgroup_index_in_threadgroup]])
530{
531    // Load vector into shared memory
532    threadgroup float vec_tile[VEC_TILE_SIZE];
533    for (uint i = tid; i < cols; i += 256) {
534        vec_tile[i] = vector[i];
535    }
536    threadgroup_barrier(mem_flags::mem_threadgroup);
537
538    // Each simdgroup handles 4 consecutive rows (lower register pressure)
539    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
540    if (row_base >= rows) return;
541
542    // Q8_0: each block is 34 bytes for 32 elements
543    uint blocks_per_row = cols / 32;
544    uint row_bytes = blocks_per_row * 34;
545
546    // Pointers to each row's Q8_0 data
547    device const uchar* r0 = matrix + row_base * row_bytes;
548    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
549    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
550    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
551
552    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
553
554    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
555        uint bb = blk * 34;
556        uint vb = blk * 32;
557
558        // Prefetch all 4 scales
559        float sc0 = float(*(device const half*)(r0 + bb));
560        float sc1 = float(*(device const half*)(r1 + bb));
561        float sc2 = float(*(device const half*)(r2 + bb));
562        float sc3 = float(*(device const half*)(r3 + bb));
563
564        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
565        // Q8_0 block layout where the int8 data starts at offset +2 from a
566        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
567        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
568        // reduction in memory transactions. Metal's char16/packed_char16 are
569        // reserved types and packed_*int4 require >=4-byte alignment which
570        // this layout does not provide, so packed_short4 is the widest valid
571        // vectorized load.
572        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
573        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
574        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
575        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
576
577        // Load all 8 float4 vector values for this 32-element block from shared memory
578        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
579        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
580        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
581        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
582        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
583        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
584        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
585        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
586
587        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
588        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
589        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
590            short4 _s = short4(SHORT4); \
591            char2 _a = as_type<char2>(_s.x); \
592            char2 _b = as_type<char2>(_s.y); \
593            char2 _c = as_type<char2>(_s.z); \
594            char2 _d = as_type<char2>(_s.w); \
595            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
596            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
597        }
598
599        float4 f0, f1;
600        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
601
602        // Row 0: 4 short4 loads cover 32 int8 weights
603        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
604        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
605        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
606        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
607
608        // Row 1
609        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
610        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
611        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
612        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
613
614        // Row 2
615        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
616        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
617        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
618        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
619
620        // Row 3
621        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
622        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
623        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
624        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
625
626        #undef Q8_UNPACK8
627
628        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
629    }
630
631    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
632    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
633
634    if (simd_lane == 0) {
635        if (row_base     < rows) output[row_base]     = sum0;
636        if (row_base + 1 < rows) output[row_base + 1] = sum1;
637        if (row_base + 2 < rows) output[row_base + 2] = sum2;
638        if (row_base + 3 < rows) output[row_base + 3] = sum3;
639    }
640}
641
642// ── matmul_vec_q4 ─────────────────────────────────────────────────────
643// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
644// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
645// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
646// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
647//
648// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
649// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
650kernel void matmul_vec_q4(
651    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
652    device const float* vector   [[buffer(1)]],  // f32 input
653    device float* output         [[buffer(2)]],
654    constant uint& rows          [[buffer(3)]],
655    constant uint& cols          [[buffer(4)]],  // number of elements per row
656    uint tgid [[threadgroup_position_in_grid]],
657    uint tid [[thread_index_in_threadgroup]],
658    uint simd_lane [[thread_index_in_simdgroup]],
659    uint simd_id [[simdgroup_index_in_threadgroup]])
660{
661    // Load vector into shared memory
662    threadgroup float vec_tile[VEC_TILE_SIZE];
663    for (uint i = tid; i < cols; i += 256) {
664        vec_tile[i] = vector[i];
665    }
666    threadgroup_barrier(mem_flags::mem_threadgroup);
667
668    // Each simdgroup handles 4 consecutive rows
669    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
670    if (row_base >= rows) return;
671
672    // Q4_0: each block is 18 bytes for 32 elements
673    uint blocks_per_row = cols / 32;
674    uint row_bytes = blocks_per_row * 18;
675
676    // Pointers to each row's Q4_0 data
677    device const uchar* r0 = matrix + row_base * row_bytes;
678    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
679    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
680    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
681
682    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
683
684    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
685        uint bb = blk * 18;
686        uint vb = blk * 32;
687
688        // Prefetch all 4 scales
689        float sc0 = float(*(device const half*)(r0 + bb));
690        float sc1 = float(*(device const half*)(r1 + bb));
691        float sc2 = float(*(device const half*)(r2 + bb));
692        float sc3 = float(*(device const half*)(r3 + bb));
693
694        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
695        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
696        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
697        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
698        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
699
700        // Load 8 float4 vector values for 32 elements from shared memory
701        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
702        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
703        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
704        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
705        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
706        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
707        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
708        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
709        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
710
711        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
712        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
713        float bd0=0, bd1=0, bd2=0, bd3=0;
714        uchar4 b;
715
716        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
717        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
718                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
719                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
720                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
721        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
722                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
723                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
724                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
725        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
726                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
727                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
728                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
729        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
730                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
731                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
732                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
733
734        // Row 1
735        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
736                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
737                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
738                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
739        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
740                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
741                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
742                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
743        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
744                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
745                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
746                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
747        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
748                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
749                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
750                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
751
752        // Row 2
753        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
754                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
755                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
756                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
757        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
758                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
759                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
760                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
761        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
762                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
763                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
764                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
765        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
766                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
767                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
768                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
769
770        // Row 3
771        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
772                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
773                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
774                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
775        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
776                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
777                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
778                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
779        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
780                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
781                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
782                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
783        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
784                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
785                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
786                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
787
788        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
789    }
790
791    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
792    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
793
794    if (simd_lane == 0) {
795        if (row_base     < rows) output[row_base]     = sum0;
796        if (row_base + 1 < rows) output[row_base + 1] = sum1;
797        if (row_base + 2 < rows) output[row_base + 2] = sum2;
798        if (row_base + 3 < rows) output[row_base + 3] = sum3;
799    }
800}
801
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention: one threadgroup per query head.
//
// The kernel runs three phases separated by threadgroup barriers:
//   1. scores[s] = (q · k_s) * rsqrt(head_dim). Each simdgroup owns one
//      position s at a time; its 32 lanes split head_dim and combine their
//      partial products with simd_sum.
//   2. Softmax over scores[0..seq_len): per-simdgroup simd_max / simd_sum,
//      then thread 0 folds the 8 per-simdgroup partials.
//   3. output = sum_s scores[s] * v_s. Each simdgroup owns one 4-wide slice
//      of head_dim at a time; its lanes split the sequence positions.
//
// Loop strides are hard-coded for 256 threads (8 simdgroups of 32 lanes) per
// threadgroup. scores[] is indexed by position, so seq_len must not exceed
// ATTN_SCORES_SIZE.
//
// Buffers:
//   q:       [num_heads * head_dim]                   current query (float)
//   k_cache: [max_seq_len * num_kv_heads * head_dim]  keys (half)
//   v_cache: [max_seq_len * num_kv_heads * head_dim]  values (half)
//   output:  [num_heads * head_dim]                   attention result (float)
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const half*  k_cache  [[buffer(1)]],
    device const half*  v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint head = tgid;
    if (head >= num_heads) return;

    // Threadgroup scratch: softmax weights plus one reduction slot per simdgroup.
    threadgroup float scores[ATTN_SCORES_SIZE];
    threadgroup float shared_max[8];
    threadgroup float shared_sum[8];

    // Several query heads map onto one KV head.
    const uint heads_per_kv = num_heads / num_kv_heads;
    const uint kv_head      = head / heads_per_kv;

    const uint q_off       = head * head_dim;          // start of this head's query
    const uint kv_stride   = num_kv_heads * head_dim;  // elements between positions
    const uint kv_head_off = kv_head * head_dim;       // this KV head within a position
    const uint head_dim4   = head_dim / 4;             // number of 4-wide chunks
    const uint vec_end     = head_dim4 * 4;            // first index left to the scalar tail

    // ── Phase 1: scaled Q·K^T ────────────────────────────────────────────
    const float score_scale = fast::rsqrt(float(head_dim));
    for (uint pos = simd_id; pos < seq_len; pos += 8) {
        const uint k_off = pos * kv_stride + kv_head_off;

        float acc = 0.0;
        // Vector part: each lane takes 4-wide chunks, stepping by 32 chunks.
        for (uint chunk = simd_lane; chunk < head_dim4; chunk += 32) {
            const uint d = chunk * 4;
            const float4 qv = *(device const float4*)(q + q_off + d);
            const float4 kv = float4(*(device const half4*)(k_cache + k_off + d));
            acc += qv.x * kv.x + qv.y * kv.y + qv.z * kv.z + qv.w * kv.w;
        }
        // Scalar tail for the head_dim % 4 trailing elements.
        for (uint d = vec_end + simd_lane; d < head_dim; d += 32) {
            acc += q[q_off + d] * float(k_cache[k_off + d]);
        }

        acc = simd_sum(acc);
        if (simd_lane == 0) {
            scores[pos] = acc * score_scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 2: softmax ─────────────────────────────────────────────────
    // 2a. Maximum score, subtracted before exponentiation.
    float lane_max = -INFINITY;
    for (uint pos = tid; pos < seq_len; pos += 256) {
        lane_max = max(lane_max, scores[pos]);
    }
    lane_max = simd_max(lane_max);
    if (simd_lane == 0) shared_max[simd_id] = lane_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint g = 1; g < 8; g++) m = max(m, shared_max[g]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float max_val = shared_max[0];

    // 2b. Exponentiate in place and accumulate the normaliser.
    float lane_sum = 0.0;
    for (uint pos = tid; pos < seq_len; pos += 256) {
        const float e = fast::exp(scores[pos] - max_val);
        scores[pos] = e;
        lane_sum += e;
    }
    lane_sum = simd_sum(lane_sum);
    if (simd_lane == 0) shared_sum[simd_id] = lane_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint g = 0; g < 8; g++) total += shared_sum[g];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float inv_sum = shared_sum[0];

    // 2c. Normalise so the weights sum to one.
    for (uint pos = tid; pos < seq_len; pos += 256) {
        scores[pos] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 3: output = scores · V ─────────────────────────────────────
    // Each simdgroup takes 4-wide chunks of head_dim, stepping by 8 chunks.
    // Within a chunk, lane L accumulates positions L, L+32, L+64, ... and the
    // 32 lane partials are folded with simd_sum, one component at a time.
    for (uint chunk = simd_id; chunk < head_dim4; chunk += 8) {
        const uint d = chunk * 4;
        const uint v_base = kv_head_off + d;

        float4 acc4 = float4(0.0);
        for (uint pos = simd_lane; pos < seq_len; pos += 32) {
            const half4 vv = *(device const half4*)(v_cache + pos * kv_stride + v_base);
            acc4 += scores[pos] * float4(vv);
        }
        acc4.x = simd_sum(acc4.x);
        acc4.y = simd_sum(acc4.y);
        acc4.z = simd_sum(acc4.z);
        acc4.w = simd_sum(acc4.w);

        if (simd_lane == 0) {
            *(device float4*)(output + q_off + d) = acc4;
        }
    }
    // Scalar tail for the head_dim % 4 trailing elements, same
    // one-simdgroup-per-element layout.
    for (uint d = vec_end + simd_id; d < head_dim; d += 8) {
        const uint v_base = kv_head_off + d;

        float acc = 0.0;
        for (uint pos = simd_lane; pos < seq_len; pos += 32) {
            acc += scores[pos] * float(v_cache[pos * kv_stride + v_base]);
        }
        acc = simd_sum(acc);

        if (simd_lane == 0) {
            output[q_off + d] = acc;
        }
    }
}
951
952// ── Batched prefill kernels ────────────────────────────────────────────
953// These kernels process M input vectors against the same weight matrix
954// in a single dispatch, converting mat-vec into mat-mat for better GPU
955// utilization during prompt prefill.
956
// ── rms_norm_batch ─────────────────────────────────────────────────────
// RMS normalization over a batch of vectors:
//   output[t, i] = input[t, i] * rsqrt(mean_i(input[t, i]^2) + eps) * weight[i]
// Threadgroup `tgid` processes token row `tgid`; threadgroups at or beyond
// num_tokens exit immediately.
// The reduction strides assume 256 threads (8 simdgroups of 32) per threadgroup.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    if (tgid >= num_tokens) return;

    // One partial sum per simdgroup; slot 0 is reused to broadcast the result.
    threadgroup float sg_partial[8];

    const uint row_off = tgid * n;
    const uint lane    = tid % 32;
    const uint sg      = tid / 32;

    // Each thread squares and accumulates a 256-strided slice of the row.
    float partial = 0.0f;
    for (uint i = tid; i < n; i += 256) {
        const float x = input[row_off + i];
        partial += x * x;
    }

    // Fold the 32 lanes of each simdgroup, then publish one value per simdgroup.
    partial = simd_sum(partial);
    if (lane == 0) {
        sg_partial[sg] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 combines the simdgroup partials into 1 / rms.
    if (tid == 0) {
        float sum_sq = 0.0f;
        for (uint g = 0; g < 8; g++) {
            sum_sq += sg_partial[g];
        }
        sg_partial[0] = fast::rsqrt(sum_sq / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float inv_rms = sg_partial[0];

    // Scale the row and apply the per-channel weight.
    for (uint i = tid; i < n; i += 256) {
        output[row_off + i] = input[row_off + i] * inv_rms * weight[i];
    }
}
1006
// ── rope_batch ─────────────────────────────────────────────────────────
// Rotary position embedding applied in place to a batch of vectors, each
// with its own position.
// data layout: [M, num_heads * head_dim], positions: [M]
// One thread per (token, head, pair); a pair is the two adjacent elements
// (2i, 2i+1) of a head, rotated by angle position * theta^(-2i / head_dim).
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint half_dim        = head_dim / 2;
    const uint pairs_per_token = num_heads * half_dim;
    if (id >= num_tokens * pairs_per_token) return;

    // Unflatten the thread id into (token, head, pair).
    const uint token  = id / pairs_per_token;
    const uint within = id % pairs_per_token;
    const uint head   = within / half_dim;
    const uint pair   = within % half_dim;

    // Index of the pair's first element inside `data`.
    const uint lo = token * (num_heads * head_dim) + head * head_dim + 2 * pair;
    const uint hi = lo + 1;

    // Rotation angle for this pair at this token's position.
    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle    = float(positions[token]) * inv_freq;
    const float cos_a    = cos(angle);
    const float sin_a    = sin(angle);

    // Read both inputs before writing: the rotation is done in place.
    const float x0 = data[lo];
    const float x1 = data[hi];
    data[lo] = x0 * cos_a - x1 * sin_a;
    data[hi] = x0 * sin_a + x1 * cos_a;
}
1041
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-multiply over a batch. gate_up is [M, 2*n]: each row holds the
// n gate values followed by the n up values.
//   output[t, i] = silu(gate_up[t, i]) * gate_up[t, n + i]
// One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint token   = id / n;
    const uint col     = id % n;
    const uint row_off = token * 2 * n;   // start of this token's gate|up row

    const float gate = gate_up[row_off + col];
    const float up   = gate_up[row_off + n + col];

    // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
    const float activated = gate / (1.0f + fast::exp(-gate));

    // token * n + col is exactly the flat thread id.
    output[id] = activated * up;
}
1061
// ── gelu_mul_fused_batch ───────────────────────────────────────────────
// Fused tanh-approximated GELU × multiply over a batch (Gemma-1 FFN,
// `gelu_pytorch_tanh`). Same [M, 2*n] gate|up layout as silu_mul_fused_batch:
//   output[t, i] = gelu_tanh(gate_up[t, i]) * gate_up[t, n + i]
// One thread per output element.
kernel void gelu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    constexpr float SQRT_2_OVER_PI = 0.7978845608f;

    if (id >= num_tokens * n) return;

    const uint token   = id / n;
    const uint col     = id % n;
    const uint row_off = token * 2 * n;   // start of this token's gate|up row

    const float gate = gate_up[row_off + col];
    const float up   = gate_up[row_off + n + col];

    // gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    const float inner     = SQRT_2_OVER_PI * (gate + 0.044715f * gate * gate * gate);
    const float activated = 0.5f * gate * (1.0f + precise::tanh(inner));

    // token * n + col is exactly the flat thread id.
    output[id] = activated * up;
}
1084
// ── add_inplace_batch ──────────────────────────────────────────────────
// In-place residual add over a flattened batch: a[i] += b[i] for i < total.
// One thread per element; threads at or beyond `total` do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1096
// ── scale_buffer ───────────────────────────────────────────────────────
// In-place scalar multiply: data[i] *= scale for i < count.
// The Gemma-1 prefill path uses it to scale embeddings by
// `sqrt(hidden_size)` right after the batched embed lookup, matching the
// HF reference forward pass.
// One thread per element; threads at or beyond `count` do nothing.
kernel void scale_buffer(
    device float* data      [[buffer(0)]],
    constant float& scale   [[buffer(1)]],
    constant uint& count    [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    if (id < count) {
        data[id] = data[id] * scale;
    }
}
1111
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M embedding rows into a contiguous batch buffer:
//   output[t, d] = embed[tokens[t], d]
// tokens: [M] token IDs, each selecting one row of `dim` floats.
// One thread per output element. Token IDs are not range-checked here.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    const uint slot     = id / dim;      // which batch entry
    const uint col      = id % dim;      // which element of the row
    const uint token_id = tokens[slot];  // row of the embedding table

    output[id] = embed[token_id * dim + col];
}
1129
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: outputs[t, r] = dot(matrix[r, :], inputs[t, :])
// for M input vectors against one shared weight matrix.
// Grid: ceil(rows/ROWS_PER_TG) * M threadgroups; each threadgroup handles one
// (token, row_group) pair and each simdgroup inside it produces 8 consecutive
// output rows.
// The token's input vector is staged in threadgroup memory, so cols must not
// exceed VEC_TILE_SIZE. The staging loop assumes 256 threads per threadgroup.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Rows accumulated by one simdgroup.
    constexpr uint kRows = 8;

    // Split the flat threadgroup index into (token, row group).
    const uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    const uint token = tgid / groups_per_token;
    const uint group = tgid % groups_per_token;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Simdgroups whose first row is past the matrix have nothing to do.
    // This exit comes after the barrier, so every thread still reaches it.
    const uint row_base = group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Element offset of each row, plus its float4 accumulator.
    uint   row_off[kRows];
    float4 acc4[kRows];
    for (uint r = 0; r < kRows; ++r) {
        row_off[r] = (row_base + r) * cols;
        acc4[r]    = float4(0.0f);
    }

    // Vector part: the largest multiple of 128 columns (32 lanes × float4).
    const uint cols_vec4 = cols & ~127u;
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        const float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        for (uint r = 0; r < kRows; ++r) {
            // Rows past the end of the matrix are never loaded.
            if (row_base + r < rows) {
                acc4[r] += *reinterpret_cast<device const float4*>(matrix + row_off[r] + j) * v;
            }
        }
    }

    // Collapse each float4 accumulator to a scalar.
    float acc[kRows];
    for (uint r = 0; r < kRows; ++r) {
        acc[r] = acc4[r].x + acc4[r].y + acc4[r].z + acc4[r].w;
    }

    // Scalar tail: the remaining cols % 128 columns, one per lane per step.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        const float vv = vec_tile[j];
        for (uint r = 0; r < kRows; ++r) {
            if (row_base + r < rows) {
                acc[r] += matrix[row_off[r] + j] * vv;
            }
        }
    }

    // Reduce across the simdgroup; lane 0 writes the rows that exist.
    device float* output = outputs + token * rows;
    for (uint r = 0; r < kRows; ++r) {
        const float total = simd_sum(acc[r]);
        if (simd_lane == 0 && row_base + r < rows) {
            output[row_base + r] = total;
        }
    }
}
1231
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors:
//   outputs[t, r] = dot(dequant(matrix[r, :]), inputs[t, :])
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups; each threadgroup handles
// one (token, row_group) pair and each simdgroup inside it produces 4
// consecutive output rows.
//
// Weight layout, as read below: every row is cols/32 blocks of 34 bytes, a
// half-precision scale followed by 32 signed 8-bit quants. Any cols % 32
// remainder is ignored.
// The token's input vector is staged in threadgroup memory, so cols must not
// exceed VEC_TILE_SIZE. The staging loop assumes 256 threads per threadgroup.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Split the flat threadgroup index into (token, row group).
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Simdgroups whose first row is past the matrix have nothing to do.
    // This exit comes after the barrier, so every thread still reaches it.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // The block loop below loads all four rows unconditionally. When rows is
    // not a multiple of 4, rows past the end are clamped to the last valid row
    // so those loads stay inside the weight buffer; the duplicate sums are
    // dropped by the guarded stores at the end. (row_base < rows here, so
    // rows >= 1 and rows - 1 cannot underflow.)
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Unpack one packed_short4 (8 bytes) into two float4s of signed quants:
    // OUT_LO takes bytes 0..3, OUT_HI takes bytes 4..7.
    #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
        short4 _s = short4(SHORT4); \
        char2 _a = as_type<char2>(_s.x); \
        char2 _b = as_type<char2>(_s.y); \
        char2 _c = as_type<char2>(_s.z); \
        char2 _d = as_type<char2>(_s.w); \
        (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
        (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
    }

    // Each lane walks the blocks of the row with a stride of 32 blocks.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // element offset of this block within the input vector

        // Per-block scale of each of the four rows.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 input elements this block multiplies against.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        // Unscaled integer dot product of the block, one row at a time.
        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    #undef Q8_UNPACK8

    // Fold the per-lane block sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes only the rows that actually exist.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1347
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// Batched Q8_0 matrix × vector product that shares each weight read across
// a tile of tokens instead of re-reading the weights once per token.
//
// Q8_0 layout as read here: every row is cols/32 blocks of 34 bytes — a
// 2-byte half scale followed by 32 int8 quants.
//
// Work split:
//   * tgid.x selects a band of Q8_ROWS_PER_TG rows, tgid.y a tile of
//     TOKENS_PER_TG_Q8 consecutive tokens.
//   * Each simdgroup owns Q8_ROWS_PER_SG consecutive rows (the body is
//     written for exactly 4) and its lanes stride over the row's blocks:
//     lane L handles blocks L, L+32, L+64, …
//   * A block's quants are unpacked once into registers and dotted against
//     every valid token of the tile.  Token vectors are read straight from
//     device memory; no threadgroup memory is used.
//   * simd_sum folds the per-lane partial sums and lane 0 writes the result.
//
// Expected grid: (ceil(rows / Q8_ROWS_PER_TG), ceil(M / TOKENS_PER_TG_Q8)).
// Assumes cols is a multiple of 32 and that every input row is 16-byte
// aligned (the token vectors are loaded as float4).
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],    // unused; kept for signature compatibility
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    // Both values are uniform across the simdgroup, so either every lane
    // leaves here or none does — the simd_sum calls below stay convergent.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Number of valid tokens in this tile (1..TOKENS_PER_TG_Q8).
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4 the last simdgroup owns fewer than
    // four real rows.  Clamp the surplus row indices to the last valid row
    // so the weight reads stay inside `matrix`; the duplicated results are
    // discarded by the `< rows` guards on the stores.
    uint last_row = rows - 1;  // rows >= 1 here because row_base < rows
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators sTR: token T of the tile × row R of the simdgroup.
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    // Expand 8 packed int8 values (one packed_short4) into two float4s.
    #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
        short4 _s = short4(SHORT4); \
        char2 _a = as_type<char2>(_s.x); \
        char2 _b = as_type<char2>(_s.y); \
        char2 _c = as_type<char2>(_s.z); \
        char2 _d = as_type<char2>(_s.w); \
        (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
        (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
    }

    // Dot token T's 32-wide slice against the four unpacked weight rows and
    // add the block-scaled results to accumulators S0..S3.
    #define Q8_GEMM_TOKEN(T, S0, S1, S2, S3) { \
        device const float* _x = inputs + (tok_base + (T)) * cols + vb; \
        float4 v0 = *(device const float4*)(_x); \
        float4 v1 = *(device const float4*)(_x + 4); \
        float4 v2 = *(device const float4*)(_x + 8); \
        float4 v3 = *(device const float4*)(_x + 12); \
        float4 v4 = *(device const float4*)(_x + 16); \
        float4 v5 = *(device const float4*)(_x + 20); \
        float4 v6 = *(device const float4*)(_x + 24); \
        float4 v7 = *(device const float4*)(_x + 28); \
        float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3) \
                  + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7); \
        float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3) \
                  + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7); \
        float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3) \
                  + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7); \
        float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3) \
                  + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7); \
        (S0) += sc0 * bd0; (S1) += sc1 * bd1; (S2) += sc2 * bd2; (S3) += sc3 * bd3; \
    }

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of this block inside a weight row
        uint vb = blk * 32;  // element offset of this block inside a token vector

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // 4 rows × 8 float4 of raw (unscaled) quants.  The block scale is
        // applied once per row after the dot product, not per element.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        // ── Accumulate every valid token against the shared weights ──
        // Token 0 is always valid (tok_count >= 1); the rest are guarded so
        // a ragged final tile never reads past the end of `inputs`.
        Q8_GEMM_TOKEN(0, s00, s01, s02, s03);
        if (tok_count > 1) Q8_GEMM_TOKEN(1, s10, s11, s12, s13);
        if (tok_count > 2) Q8_GEMM_TOKEN(2, s20, s21, s22, s23);
        if (tok_count > 3) Q8_GEMM_TOKEN(3, s30, s31, s32, s33);
    }

    #undef Q8_GEMM_TOKEN
    #undef Q8_UNPACK8

    // simdgroup reduction: sum the per-lane partials of every accumulator.
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    // Write token T's four row results, skipping rows past the matrix end.
    #define Q8_GEMM_STORE(T, S0, S1, S2, S3) { \
        device float* _o = outputs + (tok_base + (T)) * rows; \
        if (row_base     < rows) _o[row_base]     = (S0); \
        if (row_base + 1 < rows) _o[row_base + 1] = (S1); \
        if (row_base + 2 < rows) _o[row_base + 2] = (S2); \
        if (row_base + 3 < rows) _o[row_base + 3] = (S3); \
    }

    if (simd_lane == 0) {
        Q8_GEMM_STORE(0, s00, s01, s02, s03);
        if (tok_count > 1) Q8_GEMM_STORE(1, s10, s11, s12, s13);
        if (tok_count > 2) Q8_GEMM_STORE(2, s20, s21, s22, s23);
        if (tok_count > 3) Q8_GEMM_STORE(3, s30, s31, s32, s33);
    }

    #undef Q8_GEMM_STORE
}
1573
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Q8_0 GEMM built on simdgroup_matrix 8×8 tiles
// (simdgroup_multiply_accumulate).
//
// One threadgroup produces a 16-token × 16-row output tile, walking K one
// Q8_0 block at a time (32 columns, 34 bytes: half scale + 32 int8 quants):
//   1. All threads cooperatively dequantize the 16×32 weight slab into
//      w_tile and copy the 16×32 activation slab into t_tile (tokens past
//      num_tokens are zero-filled).
//   2. Each of the 4 simdgroups owns one 8×8 quadrant of the output and
//      accumulates it in float with four 8×8×8 MMAs per block.
//
// The staging code gives every thread 4 tile elements, so it expects 128
// threads (4 simdgroups of 32) per threadgroup.
//
// Preconditions (presumably enforced by the host-side dispatch, which is
// not visible from this kernel):
//   * cols % 32 == 0
//   * rows % 16 == 0 — weight rows and output columns are not bounds-checked
//   * num_tokens may be any value; a ragged final token tile is written
//     through a threadgroup scratch buffer.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint tile_row0 = tgid.x * MMA_ROW_TILE;
    const uint tile_tok0 = tgid.y * MMA_TOK_TILE;
    if (tile_row0 >= rows || tile_tok0 >= num_tokens) return;

    // Staging tiles (16×32 floats each) plus one 8×8 scratch slot per
    // simdgroup for the ragged-tile store at the end.
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];
    threadgroup float scratch[4 * 64];

    // Each thread fills tile elements [4*tid, 4*tid + 3].  Those four always
    // sit on the same 32-wide tile line, so the line index and the first K
    // offset can be computed once, outside the block loop.
    const uint fill_line   = tid >> 3;           // tile line: weight row == token slot
    const uint fill_k0     = (tid & 7u) << 2;    // first K index on that line
    const uint fill_slot   = fill_line * 32 + fill_k0;
    const uint fill_tok    = tile_tok0 + fill_line;
    const bool fill_tok_ok = fill_tok < num_tokens;
    const uint fill_in_off = fill_tok * cols + fill_k0;  // only dereferenced when fill_tok_ok

    const uint q_blocks    = cols / 32;
    const uint q_row_bytes = q_blocks * 34;
    device const uchar* fill_wrow = matrix + (tile_row0 + fill_line) * q_row_bytes;

    // Simdgroup → 8×8 quadrant of the 16×16 output (2×2 layout).
    const uint sub_tok = (simd_id >> 1) * 8;     // token offset inside the tile
    const uint sub_row = (simd_id & 1u) * 8;     // row offset inside the tile
    threadgroup const float* a_src = t_tile + sub_tok * 32;
    threadgroup const float* b_src = w_tile + sub_row * 32;

    simdgroup_matrix<float, 8, 8> acc = simdgroup_matrix<float, 8, 8>(0.0f);

    for (uint blk = 0; blk < q_blocks; ++blk) {
        // ── Stage this K block: dequantized weights + activations ──
        device const uchar* qblk = fill_wrow + blk * 34;
        const float scale = float(*(device const half*)qblk);
        device const int8_t* quants = (device const int8_t*)(qblk + 2) + fill_k0;
        const uint in_off = fill_in_off + blk * 32;
        for (uint j = 0; j < 4; ++j) {
            w_tile[fill_slot + j] = float(int(quants[j])) * scale;
            t_tile[fill_slot + j] = fill_tok_ok ? inputs[in_off + j] : 0.0f;
        }

        // Tiles must be completely written before any simdgroup loads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Four 8×8×8 MMAs cover the 32-wide K block ──
        //   A[m, k] = t_tile[(sub_tok + m) * 32 + k0 + k]
        //   B[k, r] = w_tile[(sub_row + r) * 32 + k0 + k]   (transposed load)
        for (uint k0 = 0; k0 < 32; k0 += 8) {
            simdgroup_matrix<float, 8, 8> a_frag, b_frag;
            simdgroup_load(a_frag, a_src + k0, 32, ulong2(0, 0), false);
            simdgroup_load(b_frag, b_src + k0, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(acc, a_frag, b_frag, acc);
        }

        // Nobody may overwrite the tiles while another simdgroup still reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Write the quadrant to outputs[tok * rows + row] ──
    const uint out_tok = tile_tok0 + sub_tok;
    const uint out_row = tile_row0 + sub_row;

    if (out_tok + 8 <= num_tokens) {
        // All 8 token lines are in range: store the fragment directly.
        simdgroup_store(acc, outputs + out_tok * rows + out_row, rows);
        return;
    }
    if (out_tok >= num_tokens) return;  // quadrant lies entirely past the last token

    // Ragged tail: park the fragment in this simdgroup's scratch slot and
    // let one lane copy out only the valid token lines.
    threadgroup float* stage = scratch + simd_id * 64;
    simdgroup_store(acc, stage, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);

    if ((tid & 31u) != 0) return;

    const uint valid = num_tokens - out_tok;  // 1..7
    for (uint m = 0; m < valid; ++m) {
        device float* dst = outputs + (out_tok + m) * rows + out_row;
        threadgroup const float* src = stage + m * 8;
        for (uint j = 0; j < 8; ++j) {
            dst[j] = src[j];
        }
    }
}
1702
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// 32×32-tile sibling of matmul_q8_mma.
//
// One threadgroup produces a 32-token × 32-row output tile, consuming one
// Q8_0 block (K = 32) per iteration.  The staging code gives every thread
// 4 tile elements, so it expects 256 threads (8 simdgroups).  Each
// simdgroup owns two 8×8 accumulators that sit side by side along the
// weight-row axis:
//
//     simd_id = 2 * tok_idx + row_half        (tok_idx ∈ [0,3], row_half ∈ [0,1])
//     acc_a -> (token tok_idx*8, row row_half*16 + 0)
//     acc_b -> (token tok_idx*8, row row_half*16 + 8)
//
// so every loaded token fragment feeds two MMAs.
//
// Preconditions (presumably enforced by the host-side dispatch, which is
// not visible from this kernel):
//   * cols % 32 == 0
//   * rows % 32 == 0 — weight rows and output columns are not bounds-checked
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint tile_row0 = tgid.x * MMA32_ROW_TILE;
    const uint tile_tok0 = tgid.y * MMA32_TOK_TILE;
    if (tile_row0 >= rows || tile_tok0 >= num_tokens) return;

    // Staging tiles (32×32 floats each) plus two 8×8 scratch slots per
    // simdgroup for the ragged-tile store at the end.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];
    threadgroup float scratch[8 * 2 * 64];

    // Each thread fills tile elements [4*tid, 4*tid + 3]; all four share one
    // 32-wide tile line, so the coordinates are hoisted out of the K loop.
    const uint fill_line   = tid >> 3;           // tile line: weight row == token slot
    const uint fill_k0     = (tid & 7u) << 2;    // first K index on that line
    const uint fill_slot   = fill_line * 32 + fill_k0;
    const uint fill_tok    = tile_tok0 + fill_line;
    const bool fill_tok_ok = fill_tok < num_tokens;
    const uint fill_in_off = fill_tok * cols + fill_k0;  // only dereferenced when fill_tok_ok

    const uint q_blocks    = cols / 32;
    const uint q_row_bytes = q_blocks * 34;
    device const uchar* fill_wrow = matrix + (tile_row0 + fill_line) * q_row_bytes;

    // Simdgroup → 8 tokens × 16 rows strip of the output tile.
    const uint sub_tok   = (simd_id >> 1) * 8;
    const uint sub_row_a = (simd_id & 1u) * 16;
    const uint sub_row_b = sub_row_a + 8;
    threadgroup const float* a_src  = t_tile + sub_tok * 32;
    threadgroup const float* ba_src = w_tile + sub_row_a * 32;
    threadgroup const float* bb_src = w_tile + sub_row_b * 32;

    simdgroup_matrix<float, 8, 8> acc_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> acc_b = simdgroup_matrix<float, 8, 8>(0.0f);

    for (uint blk = 0; blk < q_blocks; ++blk) {
        // Stage this K block: dequantized weights + activations (zero-filled
        // for token slots past num_tokens).
        device const uchar* qblk = fill_wrow + blk * 34;
        const float scale = float(*(device const half*)qblk);
        device const int8_t* quants = (device const int8_t*)(qblk + 2) + fill_k0;
        const uint in_off = fill_in_off + blk * 32;
        for (uint j = 0; j < 4; ++j) {
            w_tile[fill_slot + j] = float(int(quants[j])) * scale;
            t_tile[fill_slot + j] = fill_tok_ok ? inputs[in_off + j] : 0.0f;
        }

        // Tiles must be completely written before any simdgroup loads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four K sub-chunks of 8; the token fragment is loaded once and
        // multiplied against both weight fragments (loaded transposed).
        for (uint k0 = 0; k0 < 32; k0 += 8) {
            simdgroup_matrix<float, 8, 8> a_frag, b_frag_a, b_frag_b;
            simdgroup_load(a_frag,   a_src  + k0, 32, ulong2(0, 0), false);
            simdgroup_load(b_frag_a, ba_src + k0, 32, ulong2(0, 0), true);
            simdgroup_load(b_frag_b, bb_src + k0, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(acc_a, a_frag, b_frag_a, acc_a);
            simdgroup_multiply_accumulate(acc_b, a_frag, b_frag_b, acc_b);
        }

        // Nobody may overwrite the tiles while another simdgroup still reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Write both accumulators to outputs[tok * rows + row] ──
    // Only the token dimension can be ragged; the row dimension relies on
    // the rows % 32 == 0 precondition.
    const uint out_tok   = tile_tok0 + sub_tok;
    const uint out_row_a = tile_row0 + sub_row_a;
    const uint out_row_b = tile_row0 + sub_row_b;

    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(acc_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(acc_b, outputs + out_tok * rows + out_row_b, rows);
        return;
    }
    if (out_tok >= num_tokens) return;  // strip lies entirely past the last token

    // Ragged tail: park both fragments in this simdgroup's scratch slots
    // and let one lane copy out only the valid token lines.
    threadgroup float* stage_a = scratch + simd_id * 128;
    threadgroup float* stage_b = stage_a + 64;
    simdgroup_store(acc_a, stage_a, 8);
    simdgroup_store(acc_b, stage_b, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);

    if ((tid & 31u) != 0) return;

    const uint valid = num_tokens - out_tok;  // 1..7
    for (uint m = 0; m < valid; ++m) {
        device float* out_line = outputs + (out_tok + m) * rows;
        device float* dst_a = out_line + out_row_a;
        device float* dst_b = out_line + out_row_b;
        threadgroup const float* src_a = stage_a + m * 8;
        threadgroup const float* src_b = stage_b + m * 8;
        for (uint j = 0; j < 8; ++j) {
            dst_a[j] = src_a[j];
            dst_b[j] = src_b[j];
        }
    }
}
1843
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// Half-precision staging variant of matmul_q8_mma32.
//
// Same 32-token × 32-row tile and simdgroup layout (8 simdgroups, two 8×8
// accumulators each), but the dequantized weights and the token
// activations are stored in threadgroup memory as `half`, which halves the
// two staging tiles (2 KB each instead of 4 KB).  The MMA operands are
// half fragments; the accumulators stay `float`.
//
// Precision notes:
//   * A Q8_0 weight is int8 × half scale (the scale is read as `half`
//     below); the product is computed in float and then rounded to half.
//   * Activations are narrowed f32 → f16 when staged.
//     NOTE(review): this assumes activations fit the half range — confirm
//     against the callers that select this kernel.
//
// Preconditions are assumed to match matmul_q8_mma32: cols % 32 == 0,
// rows % 32 == 0, 256 threads per threadgroup.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint tile_row0 = tgid.x * MMA32_ROW_TILE;
    const uint tile_tok0 = tgid.y * MMA32_TOK_TILE;
    if (tile_row0 >= rows || tile_tok0 >= num_tokens) return;

    // Half staging tiles (32×32 each) plus two float 8×8 scratch slots per
    // simdgroup for the ragged-tile store at the end.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];
    threadgroup float scratch[8 * 2 * 64];

    // Each thread fills tile elements [4*tid, 4*tid + 3]; all four share one
    // 32-wide tile line, so the coordinates are hoisted out of the K loop.
    const uint fill_line   = tid >> 3;           // tile line: weight row == token slot
    const uint fill_k0     = (tid & 7u) << 2;    // first K index on that line
    const uint fill_slot   = fill_line * 32 + fill_k0;
    const uint fill_tok    = tile_tok0 + fill_line;
    const bool fill_tok_ok = fill_tok < num_tokens;
    const uint fill_in_off = fill_tok * cols + fill_k0;  // only dereferenced when fill_tok_ok

    const uint q_blocks    = cols / 32;
    const uint q_row_bytes = q_blocks * 34;
    device const uchar* fill_wrow = matrix + (tile_row0 + fill_line) * q_row_bytes;

    // Simdgroup → 8 tokens × 16 rows strip of the output tile.
    const uint sub_tok   = (simd_id >> 1) * 8;
    const uint sub_row_a = (simd_id & 1u) * 16;
    const uint sub_row_b = sub_row_a + 8;
    threadgroup const half* a_src  = t_tile + sub_tok * 32;
    threadgroup const half* ba_src = w_tile + sub_row_a * 32;
    threadgroup const half* bb_src = w_tile + sub_row_b * 32;

    simdgroup_matrix<float, 8, 8> acc_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> acc_b = simdgroup_matrix<float, 8, 8>(0.0f);

    for (uint blk = 0; blk < q_blocks; ++blk) {
        // Stage this K block as half: dequantized weights + activations
        // (zero-filled for token slots past num_tokens).
        device const uchar* qblk = fill_wrow + blk * 34;
        const float scale = float(*(device const half*)qblk);
        device const int8_t* quants = (device const int8_t*)(qblk + 2) + fill_k0;
        const uint in_off = fill_in_off + blk * 32;
        for (uint j = 0; j < 4; ++j) {
            w_tile[fill_slot + j] = half(float(int(quants[j])) * scale);
            t_tile[fill_slot + j] = fill_tok_ok ? half(inputs[in_off + j]) : half(0);
        }

        // Tiles must be completely written before any simdgroup loads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four K sub-chunks of 8; half operands, float accumulation.
        for (uint k0 = 0; k0 < 32; k0 += 8) {
            simdgroup_matrix<half, 8, 8> a_frag, b_frag_a, b_frag_b;
            simdgroup_load(a_frag,   a_src  + k0, 32, ulong2(0, 0), false);
            simdgroup_load(b_frag_a, ba_src + k0, 32, ulong2(0, 0), true);
            simdgroup_load(b_frag_b, bb_src + k0, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(acc_a, a_frag, b_frag_a, acc_a);
            simdgroup_multiply_accumulate(acc_b, a_frag, b_frag_b, acc_b);
        }

        // Nobody may overwrite the tiles while another simdgroup still reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Write both accumulators to outputs[tok * rows + row] ──
    const uint out_tok   = tile_tok0 + sub_tok;
    const uint out_row_a = tile_row0 + sub_row_a;
    const uint out_row_b = tile_row0 + sub_row_b;

    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(acc_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(acc_b, outputs + out_tok * rows + out_row_b, rows);
        return;
    }
    if (out_tok >= num_tokens) return;  // strip lies entirely past the last token

    // Ragged tail: park both fragments in this simdgroup's scratch slots
    // and let one lane copy out only the valid token lines.
    threadgroup float* stage_a = scratch + simd_id * 128;
    threadgroup float* stage_b = stage_a + 64;
    simdgroup_store(acc_a, stage_a, 8);
    simdgroup_store(acc_b, stage_b, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);

    if ((tid & 31u) != 0) return;

    const uint valid = num_tokens - out_tok;  // 1..7
    for (uint m = 0; m < valid; ++m) {
        device float* out_line = outputs + (out_tok + m) * rows;
        device float* dst_a = out_line + out_row_a;
        device float* dst_b = out_line + out_row_b;
        threadgroup const float* src_a = stage_a + m * 8;
        threadgroup const float* src_b = stage_b + m * 8;
        for (uint j = 0; j < 8; ++j) {
            dst_a[j] = src_a[j];
            dst_b[j] = src_b[j];
        }
    }
}
1972
1973// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
1974// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
1975//
1976// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
1977// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
1978//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
1979//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
1980// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
1981//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's.  That is 4 MMAs per 4 fragment loads, versus 2 MMAs per 3
// loads in the 2-accumulator kernel (1.5× the MMA-per-load ratio).  It also
// halves the thread count per threadgroup (128 threads), which often
// improves occupancy on Apple GPUs where the concurrent-thread budget is a
// tighter limit than shared-memory size.
//
// Buffers, as indexed below:
//   matrix  : Q8_0 weights, row-major; every 32-weight block is 34 bytes
//             (one half scale followed by 32 int8 quants).
//   inputs  : f32 activations, element (token, k) at inputs[token * cols + k].
//   outputs : f32 results, element (token, row) at outputs[token * rows + row].
// Grid: tgid.x selects the row tile, tgid.y the token tile.
// The cooperative loads assume 128 threads (4 simdgroups of 32) per
// threadgroup and 32x32 tiles (MMA32_ROW_TILE == MMA32_TOK_TILE == 32,
// presumably -- confirm against the constant definitions).
// Partial tiles are supported on both axes: rows >= `rows` and tokens
// >= `num_tokens` are zero-filled on load and skipped on store.
// `cols` is expected to be a multiple of 32 (the Q8_0 block size).
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform for the whole threadgroup, so returning before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Cooperative fill: 128 threads × 8 halves = 1024 = 32*32 per tile.
    // Thread `tid` owns elements [tid*8, tid*8+8), i.e. one 8-wide run of
    // tile row tid/4, so the row index, validity and scale are per-thread
    // invariants rather than per-element values.
    uint fill_row = tid / 4;
    uint fill_k0  = (tid % 4) * 8;
    uint fill_tok = tok_base + fill_row;
    bool w_row_valid = (row_base + fill_row) < rows;
    bool tok_valid   = fill_tok < num_tokens;
    threadgroup half* w_dst = w_tile + fill_row * 32 + fill_k0;
    threadgroup half* t_dst = t_tile + fill_row * 32 + fill_k0;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Weight dequant.  Rows past the end of the matrix are zero-filled
        // instead of being read out of bounds.
        if (w_row_valid) {
            device const uchar* rp = matrix + (row_base + fill_row) * row_bytes + blk * 34;
            float sc = float(*(device const half*)rp);
            device const int8_t* q = (device const int8_t*)(rp + 2) + fill_k0;
            for (uint ii = 0; ii < 8; ii++) {
                w_dst[ii] = half(float(int(q[ii])) * sc);
            }
        } else {
            for (uint ii = 0; ii < 8; ii++) {
                w_dst[ii] = half(0);
            }
        }

        // Token tile load (f32 -> f16), zero-filling tokens past the batch.
        if (tok_valid) {
            device const float* src = inputs + fill_tok * cols + blk * 32 + fill_k0;
            for (uint ii = 0; ii < 8; ii++) {
                t_dst[ii] = half(src[ii]);
            }
        } else {
            for (uint ii = 0; ii < 8; ii++) {
                t_dst[ii] = half(0);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            // Weights are stored [row, k]; load transposed to get [k, row].
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Keep the next iteration's tile writes from racing with these loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  The fast path needs the whole 16×16 quadrant
    // to be in range on BOTH axes; the output row stride is `rows`, so an
    // unchecked column would land in the next token's row.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens) && (out_row_hi + 8 <= rows);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-tile fallback via per-simdgroup scratch.
        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base +   0, 8);
        simdgroup_store(C_01, scratch + sg_base +  64, 8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // The 32 lanes share the 256 scratch values of this simdgroup.
        // Scratch layout: acc 0 = C_00, 1 = C_01, 2 = C_10, 3 = C_11, so
        // acc / 2 picks the token half and acc % 2 picks the row half.
        uint lane = tid % 32;
        for (uint e = lane; e < 256; e += 32) {
            uint acc = e / 64;
            uint m   = (e % 64) / 8;
            uint j   = e % 8;
            uint tok = out_tok_top + (acc / 2) * 8 + m;
            uint row = out_row_lo  + (acc % 2) * 8 + j;
            if (tok < num_tokens && row < rows) {
                outputs[tok * rows + row] = scratch[sg_base + e];
            }
        }
    }
}
2127
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)).  Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Buffers, as indexed below:
//   matrix  : Q8_0 weights, row-major; every 32-weight block is 34 bytes
//             (one half scale followed by 32 int8 quants).
//   inputs  : f32 activations, element (token, k) at inputs[token * cols + k].
//   outputs : f32 results, element (token, row) at outputs[token * rows + row].
// The cooperative loads assume 128 threads (4 simdgroups of 32) per
// threadgroup and 32x32 tiles (MMA32_ROW_TILE == MMA32_TOK_TILE == 32,
// presumably -- confirm against the constant definitions).
// Partial tiles are supported on both axes: rows >= `rows` and tokens
// >= `num_tokens` are zero-filled on load and skipped on store.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform for the whole threadgroup, so returning before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Cooperative fill: thread `tid` owns elements [tid*8, tid*8+8) of each
    // tile, i.e. one 8-wide run of tile row tid/4, so the row index,
    // validity and scale are per-thread invariants.
    uint fill_row = tid / 4;
    uint fill_k0  = (tid % 4) * 8;
    uint fill_tok = tok_base + fill_row;
    bool w_row_valid = (row_base + fill_row) < rows;
    bool tok_valid   = fill_tok < num_tokens;
    threadgroup half* w_dst = w_tile + fill_row * 32 + fill_k0;
    threadgroup half* t_dst = t_tile + fill_row * 32 + fill_k0;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Weight dequant.  Rows past the end of the matrix are zero-filled
        // instead of being read out of bounds.
        if (w_row_valid) {
            device const uchar* rp = matrix + (row_base + fill_row) * row_bytes + blk * 34;
            float sc = float(*(device const half*)rp);
            device const int8_t* q = (device const int8_t*)(rp + 2) + fill_k0;
            for (uint ii = 0; ii < 8; ii++) {
                w_dst[ii] = half(float(int(q[ii])) * sc);
            }
        } else {
            for (uint ii = 0; ii < 8; ii++) {
                w_dst[ii] = half(0);
            }
        }

        // Token tile load (f32 -> f16), zero-filling tokens past the batch.
        if (tok_valid) {
            device const float* src = inputs + fill_tok * cols + blk * 32 + fill_k0;
            for (uint ii = 0; ii < 8; ii++) {
                t_dst[ii] = half(src[ii]);
            }
        } else {
            for (uint ii = 0; ii < 8; ii++) {
                t_dst[ii] = half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            // Weights are stored [row, k]; load transposed to get [k, row].
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Keep the next iteration's tile writes from racing with these loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_row_lo  = row_base + sg_row_base + 0;

    threadgroup half scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // The 32 lanes share the 256 scratch values of this simdgroup.
    // Scratch layout: acc 0 = C_00, 1 = C_01, 2 = C_10, 3 = C_11, so
    // acc / 2 picks the token half and acc % 2 picks the row half.
    // Both axes are bounds-checked: the output row stride is `rows`, so an
    // unchecked column would land in the next token's row.
    uint lane = tid % 32;
    for (uint e = lane; e < 256; e += 32) {
        uint acc = e / 64;
        uint m   = (e % 64) / 8;
        uint j   = e % 8;
        uint tok = out_tok_top + (acc / 2) * 8 + m;
        uint row = out_row_lo  + (acc % 2) * 8 + j;
        if (tok < num_tokens && row < rows) {
            outputs[tok * rows + row] = float(scratch[sg_base + e]);
        }
    }
}
2265
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place bias add over a row-major [num_tokens, rows] buffer: element
// (token, i) receives bias[i].  One thread per output element.
// Used for Qwen2 QKV bias after the fused qkv matmul.
//     out[token, i] += bias[i]    for i in 0..rows, token in 0..num_tokens
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows]
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Threads dispatched past the last element have nothing to do.
    if (id < num_tokens * rows) {
        // The flat index modulo the row width is the bias column.
        out[id] += bias[id % rows];
    }
}
2282
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup stages one input vector in threadgroup memory (the
// staging loop strides by 256, so 256 threads per threadgroup are
// expected); each simdgroup then produces 4 consecutive output rows, with
// its 32 lanes striding over the 18-byte Q4_0 blocks (half scale + 16
// bytes of packed nibbles) and combining partial sums with simd_sum.
// `cols` must fit in vec_tile (VEC_TILE_SIZE floats) and is expected to be
// a multiple of 32.

// Dot product of 4 packed bytes (8 Q4_0 weights) with the matching
// activations.  The low nibble of byte i pairs with lo[i], the high nibble
// with hi[i]; both nibbles are re-centred by subtracting 8.
static inline float q4_batch_quad_dot(uchar4 b, float4 lo, float4 hi)
{
    return float(int(b.x & 0xF) - 8) * lo.x + float(int(b.x >> 4) - 8) * hi.x
         + float(int(b.y & 0xF) - 8) * lo.y + float(int(b.y >> 4) - 8) * hi.y
         + float(int(b.z & 0xF) - 8) * lo.z + float(int(b.z >> 4) - 8) * hi.z
         + float(int(b.w & 0xF) - 8) * lo.w + float(int(b.w >> 4) - 8) * hi.w;
}

// Unscaled dot product of one 32-weight Q4_0 block (16 packed bytes at `q`)
// with 32 activations held in v0..v7.  Bytes 0..15 carry elements 0..15 in
// their low nibbles (v0..v3) and elements 16..31 in their high nibbles
// (v4..v7).
static inline float q4_batch_block_dot(
    device const packed_uchar4* q,
    float4 v0, float4 v1, float4 v2, float4 v3,
    float4 v4, float4 v5, float4 v6, float4 v7)
{
    float acc = 0;
    acc += q4_batch_quad_dot(uchar4(q[0]), v0, v4);
    acc += q4_batch_quad_dot(uchar4(q[1]), v1, v5);
    acc += q4_batch_quad_dot(uchar4(q[2]), v2, v6);
    acc += q4_batch_quad_dot(uchar4(q[3]), v3, v7);
    return acc;
}

kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // The flat threadgroup index encodes (token, row tile).
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform for the whole threadgroup, so safe ahead of the barrier.
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so a per-simdgroup early exit is fine here.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // When `rows` is not a multiple of 4, the trailing rows of the last
    // group do not exist.  Clamp their pointers to the last real row so all
    // reads stay inside `matrix`; the bogus sums are dropped by the guarded
    // stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;   // byte offset of this block within a row
        uint vb = blk * 32;   // first activation covered by this block

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Quants start 2 bytes in; packed_uchar4 tolerates that alignment.
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // The 32 activations for this block, shared by all four rows.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0 = q4_batch_block_dot(p0, v0, v1, v2, v3, v4, v5, v6, v7);
        float bd1 = q4_batch_block_dot(p1, v0, v1, v2, v3, v4, v5, v6, v7);
        float bd2 = q4_batch_block_dot(p2, v0, v1, v2, v3, v4, v5, v6, v7);
        float bd3 = q4_batch_block_dot(p3, v0, v1, v2, v3, v4, v5, v6, v7);

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2383
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter the K (or V) slice of every token in a batch QKV buffer into the
// KV cache, narrowing f32 to f16 on the way.  One thread per copied element.
// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
// dst layout: contiguous [max_seq, kv_dim] cache.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer (f32)
    device half* dst         [[buffer(1)]],  // KV cache (f16)
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // elements per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    // Only the first M * kv_dim threads map to an element.
    if (id < M * kv_dim) {
        uint row = id / kv_dim;   // token within the batch
        uint col = id % kv_dim;   // element within the K/V vector
        // Token `row` is written to cache position base_pos + row.
        device half* cache_slot = dst + (base_pos + row) * kv_dim + col;
        device const float* qkv_elem = src + row * src_stride + src_offset + col;
        *cache_slot = half(*qkv_elem);
    }
}
2406
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// Preconditions (not checked here):
//   * base_pos + M <= ATTN_SCORES_SIZE; every attended position gets one
//     slot in the `scores` threadgroup array.
//   * 256 threads (8 simdgroups of 32) per threadgroup; the strides of 8
//     and 256 and the 8-entry reduction arrays below rely on it.
//   * num_heads is a multiple of num_kv_heads (grouped-query mapping).
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const half*  k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half*  v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // Uniform for the whole threadgroup, so safe ahead of the barriers.
    if (token_idx >= M) return;

    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Loop invariants: K and V share one layout, so one stride and one
    // per-head base serve both caches; the softmax scale is constant.
    uint v_stride = num_kv_heads * head_dim;
    uint kv_base = kv_head * head_dim;
    float scale = fast::rsqrt(float(head_dim));

    // Shared memory for attention scores — sized to the effective max_seq_len
    // (4096 for all supported models) so long-context attention doesn't overflow.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction.
    // Each simdgroup takes every 8th position; its lanes split head_dim.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * v_stride + kv_base;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a: global max, per-simdgroup first, then merged by thread 0.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate relative to the max and accumulate the sum.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Guard against a zero sum so normalisation never divides by zero.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // 2c: normalise.
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
    // This is much better than the scalar version where only 64 of 256 threads are active.
    //
    // The half4 loads and the float4 store are only naturally aligned when
    // head_dim is a multiple of 4 (every per-head base offset then is too,
    // assuming the buffers themselves are suitably aligned).  Otherwise the
    // vector path is skipped and the scalar loop below handles every
    // dimension.
    uint head_dim4 = ((head_dim % 4) == 0) ? (head_dim / 4) : 0;
    uint seq_len4 = seq_len & ~3u;   // positions covered by the 4-way unrolled loop
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_base + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
                 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
                 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
                 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Scalar path: every dimension the vector path did not cover (all of
    // them when head_dim is not a multiple of 4, none otherwise).
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_base + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2532
2533// ── attention_flash_batch ─────────────────────────────────────────────
2534// Streaming attention with online softmax.  Same grid as attention_batch
2535// (M × num_heads threadgroups, one per (token, head) pair) but the scores
2536// matrix is never materialized.  K/V positions are processed in a tile of
2537// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
2538// the standard flash-attention recurrence:
2539//
2540//     m_new   = max(m_old, tile_max)
2541//     alpha   = exp(m_old - m_new)
2542//     l_new   = alpha * l_old + sum(exp(S - m_new))
2543//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
2544//     O_final = O / l
2545//
// This removes the fixed `scores[ATTN_SCORES_SIZE]` cap in attention_batch
// (which silently overflows when seq_len > ATTN_SCORES_SIZE) and keeps
// threadgroup memory use at O(head_dim + FLASH_K_TILE) instead of O(seq_len).
2549//
2550// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
// Number of K/V positions attention_flash_batch scores per streaming step;
// also the capacity of its per-tile scores_sh array.
constant constexpr uint FLASH_K_TILE = 32;
// Capacity of the q_sh / o_sh threadgroup arrays in attention_flash_batch;
// head_dim must not exceed this.
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;
2553
2554kernel void attention_flash_batch(
2555    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
2556    device const half*  k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
2557    device const half*  v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
2558    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
2559    constant uint& M                 [[buffer(4)]],
2560    constant uint& base_pos          [[buffer(5)]],
2561    constant uint& num_heads         [[buffer(6)]],
2562    constant uint& num_kv_heads      [[buffer(7)]],
2563    constant uint& head_dim          [[buffer(8)]],
2564    constant uint& q_stride          [[buffer(9)]],
2565    uint tgid [[threadgroup_position_in_grid]],
2566    uint tid [[thread_index_in_threadgroup]],
2567    uint simd_lane [[thread_index_in_simdgroup]],
2568    uint simd_id [[simdgroup_index_in_threadgroup]])
2569{
2570    uint token_idx = tgid / num_heads;
2571    uint head = tgid % num_heads;
2572    if (token_idx >= M) return;
2573
2574    uint kv_head = head / (num_heads / num_kv_heads);
2575    uint seq_len = base_pos + token_idx + 1;  // causal: attend to [0, base_pos + token_idx]
2576    uint q_off = token_idx * q_stride + head * head_dim;
2577    uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2578
2579    // Threadgroup state:
2580    //   q_sh:      Q vector for this (token, head), loaded once
2581    //   o_sh:      running output vector, updated each K tile
2582    //   scores_sh: scores for the current K tile only
2583    //   stats:     [running max, running sum]  (see flash-attention recurrence)
2584    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
2585    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
2586    threadgroup float scores_sh[FLASH_K_TILE];
2587    threadgroup float stats[2];
2588    threadgroup float sg_scratch[8];  // simdgroup-level reduction buffer
2589
2590    // --- Load Q (one row) and zero the running O ---
2591    for (uint d = tid; d < head_dim; d += 256) {
2592        q_sh[d] = q_batch[q_off + d];
2593        o_sh[d] = 0.0f;
2594    }
2595    if (tid == 0) {
2596        stats[0] = -INFINITY;
2597        stats[1] = 0.0f;
2598    }
2599    threadgroup_barrier(mem_flags::mem_threadgroup);
2600
2601    float scale = fast::rsqrt(float(head_dim));
2602    uint v_stride = num_kv_heads * head_dim;
2603    uint v_base = kv_head * head_dim;
2604
2605    // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
2606    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
2607        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);
2608
2609        // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
2610        // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
2611        for (uint ti = simd_id; ti < tile_n; ti += 8) {
2612            uint k_off = (kv_base + ti) * v_stride + v_base;  // same layout as V stride
2613            float dot = 0.0f;
2614            for (uint d = simd_lane; d < head_dim; d += 32) {
2615                dot += q_sh[d] * k_cache[k_off + d];
2616            }
2617            dot = simd_sum(dot);
2618            if (simd_lane == 0) {
2619                scores_sh[ti] = dot * scale;
2620            }
2621        }
2622        threadgroup_barrier(mem_flags::mem_threadgroup);
2623
2624        // [2] Tile max via cooperative reduction.
2625        float local_max = -INFINITY;
2626        for (uint s = tid; s < tile_n; s += 256) {
2627            local_max = max(local_max, scores_sh[s]);
2628        }
2629        local_max = simd_max(local_max);
2630        if (simd_lane == 0) {
2631            sg_scratch[simd_id] = local_max;
2632        }
2633        threadgroup_barrier(mem_flags::mem_threadgroup);
2634        // [3] Merge with running max, compute alpha, rescale running l.
2635        float m_new;
2636        float alpha;
2637        if (tid == 0) {
2638            float tile_max = sg_scratch[0];
2639            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
2640            float m_old = stats[0];
2641            m_new = max(m_old, tile_max);
2642            // First iteration: m_old = -inf → alpha = 0 (reset O).
2643            alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2644            stats[0] = m_new;
2645            stats[1] *= alpha;
2646            // Broadcast via sg_scratch.
2647            sg_scratch[0] = alpha;
2648            sg_scratch[1] = m_new;
2649        }
2650        threadgroup_barrier(mem_flags::mem_threadgroup);
2651        alpha = sg_scratch[0];
2652        m_new = sg_scratch[1];
2653
2654        // [4] Rescale running output by alpha, then compute exp(scores - m_new).
2655        for (uint d = tid; d < head_dim; d += 256) {
2656            o_sh[d] *= alpha;
2657        }
2658        for (uint s = tid; s < tile_n; s += 256) {
2659            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
2660        }
2661        threadgroup_barrier(mem_flags::mem_threadgroup);
2662
2663        // [5] Tile sum → update running l.
2664        float local_sum = 0.0f;
2665        for (uint s = tid; s < tile_n; s += 256) {
2666            local_sum += scores_sh[s];
2667        }
2668        local_sum = simd_sum(local_sum);
2669        if (simd_lane == 0) {
2670            sg_scratch[simd_id] = local_sum;
2671        }
2672        threadgroup_barrier(mem_flags::mem_threadgroup);
2673        if (tid == 0) {
2674            float tile_sum = 0.0f;
2675            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
2676            stats[1] += tile_sum;
2677        }
2678        threadgroup_barrier(mem_flags::mem_threadgroup);
2679
2680        // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
2681        for (uint d = tid; d < head_dim; d += 256) {
2682            float acc = 0.0f;
2683            for (uint s = 0; s < tile_n; s++) {
2684                acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
2685            }
2686            o_sh[d] += acc;
2687        }
2688        threadgroup_barrier(mem_flags::mem_threadgroup);
2689    }
2690
2691    // --- Normalize and write output ---
2692    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
2693    for (uint d = tid; d < head_dim; d += 256) {
2694        output_batch[out_off + d] = o_sh[d] * inv_l;
2695    }
2696}
2697
// ── attention_mma_flash_batch ─────────────────────────────────────────
// Causal flash attention for batched prefill, with both Q·K^T and P·V
// computed through simdgroup_matrix 8x8 multiply-accumulate.
//
// One threadgroup handles FLASH_MMA_Q_BLOCK (8) consecutive query tokens
// of a single head, so every K/V tile staged in threadgroup memory is
// shared by 8 query rows (attention_flash_batch handles 1 token per TG).
//
// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
// The body hard-codes that thread count in its strides and tile ownership.
// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
// to attention_batch / attention_flash_batch otherwise.
// The MMA tiling uses integer division of head_dim by 8 (Q·K^T) and by
// 4 then 8 (P·V), so head_dim is expected to be a multiple of 32.
//
// The online softmax recurrence matches attention_flash_batch, applied
// per query row: each K tile updates m[q], l[q], O[q] for q in 0..8.
constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
constant constexpr uint FLASH_MMA_K_BLOCK = 32;
constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;

kernel void attention_mma_flash_batch(
    device const float* q_batch      [[buffer(0)]],
    device const half*  k_cache      [[buffer(1)]],
    device const half*  v_cache      [[buffer(2)]],
    device float* output_batch       [[buffer(3)]],
    constant uint& M                 [[buffer(4)]],
    constant uint& base_pos          [[buffer(5)]],
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Threads per threadgroup: 4 simdgroups of 32 lanes.
    const uint tg_threads = 128;

    const uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
    const uint head = tgid.y;
    if (q_block_start >= M) return;
    // Number of real query rows in this block (the last block may be partial).
    const uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);

    // Grouped-query attention: several Q heads share one KV head.
    const uint kv_head = head / (num_heads / num_kv_heads);
    // Causal window: row q attends to kv positions [0, base_pos + q_block_start + q],
    // so the furthest position any row of the block needs is
    // base_pos + q_block_start + q_valid - 1.
    const uint seq_len = base_pos + q_block_start + q_valid;
    const float scale = fast::rsqrt(float(head_dim));

    const uint kv_stride = num_kv_heads * head_dim;
    const uint kv_base_off = kv_head * head_dim;

    // ── Threadgroup memory ──
    // q_sh:    [Q_BLOCK, head_dim] half  — query tile, loaded once
    // k_sh:    [K_BLOCK, head_dim] half  — key tile, reloaded every kv_base step
    // v_sh:    [K_BLOCK, head_dim] half  — value tile, reloaded every kv_base step
    // s_sh:    [Q_BLOCK, K_BLOCK] float  — Q·K^T scores, scaled and masked in place
    // p_sh:    [Q_BLOCK, K_BLOCK] half   — exp(score - max), input to the P·V MMA
    // o_sh:    [Q_BLOCK, head_dim] float — running (unnormalised) output
    // m_sh:    [Q_BLOCK] float           — running row max
    // l_sh:    [Q_BLOCK] float           — running softmax denominator
    // scratch: 4*Q_BLOCK floats          — [0,8): tile max then m_new,
    //                                      [8,16): alpha, [16,24): tile row sum
    threadgroup half  q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup half  k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup half  v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
    threadgroup half  p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];

    // ── Stage the Q tile; rows past q_valid are zero. Reset O, m and l. ──
    const uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
    for (uint idx = tid; idx < qblock_elems; idx += tg_threads) {
        const uint row = idx / head_dim;
        const uint col = idx % head_dim;
        half q_val = half(0);
        if (row < q_valid) {
            const uint q_off = (q_block_start + row) * q_stride + head * head_dim + col;
            q_val = half(q_batch[q_off]);
        }
        q_sh[idx] = q_val;
        o_sh[idx] = 0.0f;
    }
    if (tid < FLASH_MMA_Q_BLOCK) {
        m_sh[tid] = -INFINITY;
        l_sh[tid] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Walk the K/V cache in K_BLOCK-sized tiles ──
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
        const uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);

        // Stage the K and V tiles; positions past tile_n are zero-filled so the
        // MMA can always consume a full K_BLOCK.
        const uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
        for (uint idx = tid; idx < kv_tile_elems; idx += tg_threads) {
            const uint k_pos = idx / head_dim;
            const uint col = idx % head_dim;
            half k_val = half(0);
            half v_val = half(0);
            if (k_pos < tile_n) {
                const uint off = (kv_base + k_pos) * kv_stride + kv_base_off + col;
                k_val = k_cache[off];
                v_val = v_cache[off];
            }
            k_sh[idx] = k_val;
            v_sh[idx] = v_val;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 1: S = Q @ K^T via MMA ──
        // The [8, 32] score tile is split into four [8, 8] tiles, one per
        // simdgroup; simdgroup s owns columns [s*8, s*8 + 8).
        {
            simdgroup_matrix<float, 8, 8> s_acc = simdgroup_matrix<float, 8, 8>(0.0f);
            const uint dim_chunks = head_dim / 8;
            // First of the 8 key rows this simdgroup multiplies against.
            threadgroup half* k_rows = k_sh + (simd_id * 8) * head_dim;
            for (uint dc = 0; dc < dim_chunks; dc++) {
                simdgroup_matrix<half, 8, 8> q_tile, kt_tile;
                // Q[0:8, dc*8 : dc*8+8], loaded as stored.
                simdgroup_load(q_tile, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
                // K[simd_id*8 : +8, dc*8 : +8] loaded with transpose=true, which
                // yields the matching 8x8 block of K^T.
                simdgroup_load(kt_tile, k_rows + dc * 8, head_dim, ulong2(0, 0), true);
                simdgroup_multiply_accumulate(s_acc, q_tile, kt_tile, s_acc);
            }
            // Write the tile to s_sh[0..8, simd_id*8 .. simd_id*8+8) (row stride K_BLOCK).
            simdgroup_store(s_acc, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2a: scale and causal-mask s_sh in place ──
        // 8 x 32 = 256 scores over 128 threads → 2 per thread.
        const uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
        for (uint idx = tid; idx < s_elems; idx += tg_threads) {
            const uint row = idx / FLASH_MMA_K_BLOCK;
            const uint col = idx % FLASH_MMA_K_BLOCK;
            const uint global_q = q_block_start + row;
            const uint global_kv = kv_base + col;
            // Keep only real query rows, real key positions, and keys at or
            // before the query's own position.
            const bool keep = (row < q_valid) && (col < tile_n)
                              && (global_kv <= base_pos + global_q);
            s_sh[idx] = keep ? (s_sh[idx] * scale) : -INFINITY;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2b: per-row max of the tile ──
        // Each simdgroup reduces 2 of the 8 rows; its 32 lanes map one-to-one
        // onto the K_BLOCK = 32 columns.
        for (uint r = 0; r < 2; r++) {
            const uint row = simd_id * 2 + r;
            const float row_max = simd_max(s_sh[row * FLASH_MMA_K_BLOCK + simd_lane]);
            if (simd_lane == 0) {
                scratch[row] = row_max;  // tile max for this row
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2c: fold the tile max into the running max ──
        // One thread per row; publishes m_new and alpha through scratch.
        if (tid < FLASH_MMA_Q_BLOCK) {
            const uint row = tid;
            const float m_old = m_sh[row];
            const float tile_max = scratch[row];
            const float m_new = max(m_old, tile_max);
            // While nothing has been accumulated yet (m_old = -inf) alpha is
            // forced to 0 rather than evaluating exp(-inf - m_new).
            const float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            m_sh[row] = m_new;
            l_sh[row] = l_sh[row] * alpha;
            scratch[row] = m_new;                        // read by phase 2d
            scratch[FLASH_MMA_Q_BLOCK + row] = alpha;    // read by phase 3
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2d: P = exp(S - m_new) into p_sh (half), plus row sums ──
        for (uint r = 0; r < 2; r++) {
            const uint row = simd_id * 2 + r;
            const uint idx = row * FLASH_MMA_K_BLOCK + simd_lane;
            const float p = fast::exp(s_sh[idx] - scratch[row]);
            p_sh[idx] = half(p);
            const float row_sum = simd_sum(p);
            if (simd_lane == 0) {
                scratch[2 * FLASH_MMA_Q_BLOCK + row] = row_sum;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2e: l += tile row sum ──
        if (tid < FLASH_MMA_Q_BLOCK) {
            l_sh[tid] += scratch[2 * FLASH_MMA_Q_BLOCK + tid];
        }

        // ── Phase 3: rescale the running output, O[q, :] *= alpha[q] ──
        threadgroup_barrier(mem_flags::mem_threadgroup);
        for (uint idx = tid; idx < qblock_elems; idx += tg_threads) {
            o_sh[idx] *= scratch[FLASH_MMA_Q_BLOCK + idx / head_dim];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 4: O += P @ V via MMA ──
        // P is [8, 32] half and V is [32, head_dim] half. Each simdgroup owns
        // head_dim / 4 output columns, processed as 8-wide tiles
        // (head_dim = 64 → 2 tiles per simdgroup, head_dim = 128 → 4).
        {
            const uint dims_per_sg = head_dim / 4;
            const uint tiles_per_sg = dims_per_sg / 8;
            const uint k_chunks = FLASH_MMA_K_BLOCK / 8;  // 4
            for (uint t = 0; t < tiles_per_sg; t++) {
                const uint d_base = simd_id * dims_per_sg + t * 8;
                // Start from the (already rescaled) running output tile.
                simdgroup_matrix<float, 8, 8> o_tile;
                simdgroup_load(o_tile, o_sh + d_base, head_dim, ulong2(0, 0), false);
                for (uint kc = 0; kc < k_chunks; kc++) {
                    simdgroup_matrix<half, 8, 8> p_tile, v_tile;
                    simdgroup_load(p_tile, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
                                   ulong2(0, 0), false);
                    simdgroup_load(v_tile, v_sh + (kc * 8) * head_dim + d_base, head_dim,
                                   ulong2(0, 0), false);
                    simdgroup_multiply_accumulate(o_tile, p_tile, v_tile, o_tile);
                }
                simdgroup_store(o_tile, o_sh + d_base, head_dim);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Finalize: divide by l and write the real rows ──
    // Output rows are indexed as [token, num_heads * head_dim]; padded rows
    // (row >= q_valid) are never written.
    for (uint idx = tid; idx < qblock_elems; idx += tg_threads) {
        const uint row = idx / head_dim;
        if (row >= q_valid) continue;
        const uint col = idx % head_dim;
        const float l = l_sh[row];
        const float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
        const uint out_off = (q_block_start + row) * num_heads * head_dim + head * head_dim + col;
        output_batch[out_off] = o_sh[idx] * inv_l;
    }
}
2945
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Rotary position embedding for Q and K in one dispatch (one launch and
// one memory barrier fewer per layer than rotating them separately).
// Q and K share the qkv_data buffer: within each token row Q starts at
// offset 0 (num_q_heads heads) and K at offset num_q_heads * head_dim
// (num_kv_heads heads). One thread rotates one adjacent (even, odd) pair.
kernel void rope_qk_batch(
    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
    constant uint& M                 [[buffer(1)]],   // num tokens
    constant uint& base_pos          [[buffer(2)]],   // starting position
    constant uint& num_q_heads       [[buffer(3)]],
    constant uint& num_kv_heads      [[buffer(4)]],
    constant uint& head_dim          [[buffer(5)]],
    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
    constant float& theta            [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    const uint half_dim = head_dim / 2;
    const uint q_pairs = num_q_heads * half_dim;
    const uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;

    // Flat thread id → (token, pair within that token's Q+K section).
    const uint token = id / total_pairs;
    const uint pair = id % total_pairs;
    if (token >= M) return;

    // Pairs [0, q_pairs) belong to Q, the rest to K.
    const bool is_k = pair >= q_pairs;
    const uint local_pair = is_k ? (pair - q_pairs) : pair;
    const uint h = local_pair / half_dim;   // head within the Q or K section
    const uint i = local_pair % half_dim;   // pair index within the head
    const uint section_start = is_k ? (num_q_heads * head_dim) : 0u;
    const uint offset = token * qkv_stride + section_start + h * head_dim + i * 2;

    // Rotation angle = position * theta^(-2i / head_dim).
    const float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    const float angle = float(base_pos + token) * freq;
    const float cos_val = cos(angle);
    const float sin_val = sin(angle);

    // Rotate the pair in place.
    const float x0 = qkv_data[offset];
    const float x1 = qkv_data[offset + 1];
    qkv_data[offset]     = x0 * cos_val - x1 * sin_val;
    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
}
2996
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Writes K and V for a batch of tokens into their caches in one dispatch,
// converting f32 → f16 on the way. Replaces two copy_kv_batch calls, saving
// one kernel launch + memory barrier per layer.
// Thread ids [0, M*kv_dim) copy K; ids [M*kv_dim, 2*M*kv_dim) copy V.
kernel void copy_kv_both_batch(
    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride] f32
    device half* k_dst         [[buffer(1)]],  // K cache [max_seq, kv_dim] f16
    device half* v_dst         [[buffer(2)]],  // V cache [max_seq, kv_dim] f16
    constant uint& M           [[buffer(3)]],  // num tokens in batch
    constant uint& kv_dim      [[buffer(4)]],  // elements per KV vector
    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    const uint total_kv = M * kv_dim;
    if (id >= total_kv * 2) return;

    // The second half of the id range handles V.
    const bool is_v = id >= total_kv;
    const uint local_id = is_v ? (id - total_kv) : id;
    const uint token = local_id / kv_dim;
    const uint d = local_id % kv_dim;

    const uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
    const uint dst_off = (base_pos + token) * kv_dim + d;

    device half* dst = is_v ? v_dst : k_dst;
    dst[dst_off] = half(src[src_off]);
}
3031"#
3032    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
3033    .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
3034}
3035
3036// ---------------------------------------------------------------------------
3037// model.rs generation
3038// ---------------------------------------------------------------------------
3039
3040fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
3041    let mut code = String::with_capacity(48 * 1024);
3042    emit_model_header(&mut code, config)?;
3043    emit_metal_model_struct(&mut code, config)?;
3044    emit_layer_buffers_struct(&mut code, config)?;
3045    emit_metal_model_impl(&mut code, config)?;
3046    emit_helper_functions(&mut code)?;
3047    Ok(code)
3048}
3049
3050fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3051    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
3052    writeln!(
3053        code,
3054        "//! Model: {} ({} layers, hidden={})",
3055        config.architecture, config.num_layers, config.hidden_size
3056    )?;
3057    writeln!(code, "//!")?;
3058    writeln!(
3059        code,
3060        "//! Uses native Metal compute pipelines via the metal crate."
3061    )?;
3062    writeln!(code)?;
3063    writeln!(code, "#![allow(dead_code)]")?;
3064    writeln!(code)?;
3065    writeln!(code, "use metal::*;")?;
3066    writeln!(code, "#[allow(unused_imports)]")?;
3067    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
3068    writeln!(code, "use std::mem;")?;
3069    writeln!(code)?;
3070
3071    // Model constants
3072    writeln!(
3073        code,
3074        "// ── Model constants ──────────────────────────────────"
3075    )?;
3076    writeln!(
3077        code,
3078        "pub const HIDDEN_SIZE: usize = {};",
3079        config.hidden_size
3080    )?;
3081    writeln!(
3082        code,
3083        "pub const INTERMEDIATE_SIZE: usize = {};",
3084        config.intermediate_size
3085    )?;
3086    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
3087    writeln!(
3088        code,
3089        "pub const NUM_HEADS: usize = {};",
3090        config.num_attention_heads
3091    )?;
3092    writeln!(
3093        code,
3094        "pub const NUM_KV_HEADS: usize = {};",
3095        config.num_kv_heads
3096    )?;
3097    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3098    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3099    let effective_seq_len = config.max_seq_len.min(4096);
3100    writeln!(
3101        code,
3102        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
3103        effective_seq_len, config.max_seq_len
3104    )?;
3105    writeln!(
3106        code,
3107        "pub const RMS_NORM_EPS: f32 = {:e};",
3108        config.rms_norm_eps
3109    )?;
3110    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3111    writeln!(
3112        code,
3113        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3114    )?;
3115    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3116    writeln!(code)?;
3117
3118    Ok(())
3119}
3120
3121fn emit_metal_model_struct(
3122    code: &mut String,
3123    config: &ModelConfig,
3124) -> Result<(), MetalCodegenError> {
3125    writeln!(
3126        code,
3127        "// ── MetalModel ──────────────────────────────────────────"
3128    )?;
3129    writeln!(code)?;
3130    writeln!(
3131        code,
3132        "/// Metal-accelerated transformer model for Apple Silicon."
3133    )?;
3134    writeln!(code, "///")?;
3135    writeln!(
3136        code,
3137        "/// Uses unified memory for zero-copy weight access and native Metal"
3138    )?;
3139    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3140    writeln!(code, "pub struct MetalModel {{")?;
3141    writeln!(code, "    device: Device,")?;
3142    writeln!(code, "    queue: CommandQueue,")?;
3143    writeln!(code)?;
3144    writeln!(code, "    // ── Compute pipelines ──")?;
3145    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
3146    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
3147    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
3148    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
3149    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
3150    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
3151    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
3152    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
3153    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
3154    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
3155    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
3156    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
3157    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
3158    writeln!(
3159        code,
3160        "    copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3161    )?;
3162    writeln!(code)?;
3163    writeln!(code, "    // ── Batched prefill pipelines ──")?;
3164    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
3165    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
3166    writeln!(
3167        code,
3168        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3169    )?;
3170    writeln!(code, "    matmul_q8_mma_pipeline: ComputePipelineState,")?;
3171    writeln!(code, "    matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3172    writeln!(
3173        code,
3174        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3175    )?;
3176    writeln!(
3177        code,
3178        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3179    )?;
3180    writeln!(
3181        code,
3182        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3183    )?;
3184    if config.qkv_bias {
3185        writeln!(code, "    add_bias_batch_pipeline: ComputePipelineState,")?;
3186    }
3187    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
3188    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
3189    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
3190    writeln!(
3191        code,
3192        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
3193    )?;
3194    writeln!(
3195        code,
3196        "    add_inplace_batch_pipeline: ComputePipelineState,"
3197    )?;
3198    writeln!(
3199        code,
3200        "    copy_embedding_batch_pipeline: ComputePipelineState,\n    scale_buffer_pipeline: ComputePipelineState,"
3201    )?;
3202    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
3203    writeln!(
3204        code,
3205        "    attention_flash_batch_pipeline: ComputePipelineState,"
3206    )?;
3207    writeln!(
3208        code,
3209        "    attention_mma_flash_batch_pipeline: ComputePipelineState,"
3210    )?;
3211    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
3212    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
3213    writeln!(
3214        code,
3215        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
3216    )?;
3217    writeln!(code)?;
3218    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
3219    writeln!(
3220        code,
3221        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3222    )?;
3223    writeln!(code, "    embed_buf: Buffer,")?;
3224    writeln!(code)?;
3225    writeln!(code, "    /// Per-layer weight buffers")?;
3226    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
3227    writeln!(code)?;
3228    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
3229    writeln!(code, "    norm_buf: Buffer,")?;
3230    writeln!(code)?;
3231    writeln!(
3232        code,
3233        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3234    )?;
3235    writeln!(code, "    lm_head_buf: Buffer,")?;
3236    writeln!(code)?;
3237    writeln!(
3238        code,
3239        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
3240    )?;
3241    writeln!(code, "    hidden_buf: Buffer,")?;
3242    writeln!(code, "    residual_buf: Buffer,")?;
3243    writeln!(code, "    normed_buf: Buffer,")?;
3244    writeln!(
3245        code,
3246        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3247    )?;
3248    writeln!(code, "    qkv_buf: Buffer,")?;
3249    writeln!(code, "    attn_out_buf: Buffer,")?;
3250    writeln!(code, "    attn_proj_buf: Buffer,")?;
3251    writeln!(
3252        code,
3253        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3254    )?;
3255    writeln!(code, "    gate_up_buf: Buffer,")?;
3256    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
3257    writeln!(code, "    ffn_out_buf: Buffer,")?;
3258    writeln!(code, "    add_tmp_buf: Buffer,")?;
3259    writeln!(code, "    logits_buf: Buffer,")?;
3260    writeln!(code)?;
3261    writeln!(code, "    // ── Batched prefill working buffers ──")?;
3262    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3263    writeln!(code, "    batch_hidden_buf: Buffer,")?;
3264    writeln!(
3265        code,
3266        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3267    )?;
3268    writeln!(code, "    batch_residual_buf: Buffer,")?;
3269    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3270    writeln!(code, "    batch_qkv_buf: Buffer,")?;
3271    writeln!(
3272        code,
3273        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3274    )?;
3275    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
3276    writeln!(
3277        code,
3278        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3279    )?;
3280    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
3281    writeln!(
3282        code,
3283        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3284    )?;
3285    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
3286    writeln!(
3287        code,
3288        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3289    )?;
3290    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
3291    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3292    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
3293    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
3294    writeln!(code, "    batch_tokens_buf: Buffer,")?;
3295    writeln!(code, "    /// Positions buffer for batch RoPE")?;
3296    writeln!(code, "    batch_positions_buf: Buffer,")?;
3297    writeln!(code)?;
3298    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
3299    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
3300    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
3301    writeln!(code)?;
3302    writeln!(code, "    // ── Inference state ──")?;
3303    writeln!(code, "    pos: usize,")?;
3304    writeln!(code)?;
3305    writeln!(
3306        code,
3307        "    /// Previous command buffer for double-buffered prefill."
3308    )?;
3309    writeln!(
3310        code,
3311        "    /// While the GPU executes token N, the CPU can encode token N+1."
3312    )?;
3313    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
3314    writeln!(code, "}}")?;
3315    writeln!(code)?;
3316
3317    Ok(())
3318}
3319
3320fn emit_layer_buffers_struct(
3321    code: &mut String,
3322    config: &ModelConfig,
3323) -> Result<(), MetalCodegenError> {
3324    writeln!(
3325        code,
3326        "/// Per-layer weight buffers for attention and FFN projections."
3327    )?;
3328    writeln!(code, "struct LayerBuffers {{")?;
3329    writeln!(code, "    attn_norm: Buffer,")?;
3330    writeln!(
3331        code,
3332        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3333    )?;
3334    writeln!(code, "    qkv_weight: Buffer,")?;
3335    if config.qkv_bias {
3336        writeln!(
3337            code,
3338            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3339        )?;
3340        writeln!(code, "    qkv_bias: Buffer,")?;
3341    }
3342    writeln!(code, "    o_weight: Buffer,")?;
3343    writeln!(code, "    ffn_norm: Buffer,")?;
3344    writeln!(
3345        code,
3346        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3347    )?;
3348    writeln!(code, "    gate_up_weight: Buffer,")?;
3349    writeln!(code, "    down_weight: Buffer,")?;
3350    writeln!(code, "}}")?;
3351    writeln!(code)?;
3352
3353    Ok(())
3354}
3355
3356fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3357    let hidden = config.hidden_size;
3358    let intermediate = config.intermediate_size;
3359    let _num_layers = config.num_layers;
3360    let _num_heads = config.num_attention_heads;
3361    let num_kv_heads = config.num_kv_heads;
3362    let head_dim = config.head_dim;
3363    let vocab = config.vocab_size;
3364    let effective_seq_len = config.max_seq_len.min(4096);
3365    let is_q8 = config.dtype == DType::Q8_0;
3366    let is_q4 = config.dtype == DType::Q4_0;
3367    let kv_dim = num_kv_heads * head_dim;
3368
3369    writeln!(code, "impl MetalModel {{")?;
3370
3371    // ── new() ──
3372    writeln!(
3373        code,
3374        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3375    )?;
3376    writeln!(code, "    ///")?;
3377    writeln!(
3378        code,
3379        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3380    )?;
3381    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3382    writeln!(
3383        code,
3384        "        let device = Device::system_default().expect(\"no Metal device found\");"
3385    )?;
3386    writeln!(code, "        let queue = device.new_command_queue();")?;
3387    writeln!(code)?;
3388
3389    // Compile shaders
3390    writeln!(
3391        code,
3392        "        // Compile Metal shaders from embedded source"
3393    )?;
3394    writeln!(
3395        code,
3396        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3397    )?;
3398    writeln!(
3399        code,
3400        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3401    )?;
3402    writeln!(
3403        code,
3404        "            .expect(\"failed to compile Metal shaders\");"
3405    )?;
3406    writeln!(code)?;
3407
3408    // Create compute pipelines
3409    // Select the FFN activation kernel. Gemma-1 uses tanh-approximate GeLU
3410    // (`gelu_pytorch_tanh`); everything else uses SiLU. The `silu_mul_fused*`
3411    // pipeline fields are reused — they simply hold the GeLU kernel when
3412    // Gemma is the target.
3413    let activation_kernel = match config.hidden_activation {
3414        HiddenActivation::SiLU => "silu_mul_fused",
3415        HiddenActivation::GeluApprox => "gelu_mul_fused",
3416    };
3417    let activation_kernel_batch = match config.hidden_activation {
3418        HiddenActivation::SiLU => "silu_mul_fused_batch",
3419        HiddenActivation::GeluApprox => "gelu_mul_fused_batch",
3420    };
3421    writeln!(code, "        // Create compute pipelines")?;
3422    for (var, fn_name) in [
3423        ("matmul_pipeline", "matmul_vec"),
3424        ("matmul_q8_pipeline", "matmul_vec_q8"),
3425        ("matmul_q4_pipeline", "matmul_vec_q4"),
3426        ("rms_norm_pipeline", "rms_norm"),
3427        ("rope_pipeline", "rope"),
3428        ("softmax_pipeline", "softmax"),
3429        ("silu_mul_pipeline", "silu_mul"),
3430        ("silu_mul_fused_pipeline", activation_kernel),
3431        ("add_pipeline", "elementwise_add"),
3432        ("attention_pipeline", "attention"),
3433        ("add_inplace_pipeline", "add_inplace"),
3434        ("copy_pipeline", "copy_buffer"),
3435        ("copy_offset_pipeline", "copy_offset"),
3436        ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3437        ("matmul_batch_pipeline", "matmul_vec_batch"),
3438        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3439        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3440        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3441        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3442        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3443        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3444        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3445        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3446        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3447        ("rope_batch_pipeline", "rope_batch"),
3448        ("silu_mul_fused_batch_pipeline", activation_kernel_batch),
3449        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3450        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3451        ("scale_buffer_pipeline", "scale_buffer"),
3452        ("attention_batch_pipeline", "attention_batch"),
3453        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3454        (
3455            "attention_mma_flash_batch_pipeline",
3456            "attention_mma_flash_batch",
3457        ),
3458        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3459        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3460        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3461    ] {
3462        writeln!(
3463            code,
3464            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3465        )?;
3466    }
3467    if config.qkv_bias {
3468        writeln!(
3469            code,
3470            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3471        )?;
3472    }
3473    writeln!(code)?;
3474
3475    // Weight loading
3476    writeln!(
3477        code,
3478        "        // Load weights into Metal shared-memory buffers"
3479    )?;
3480    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3481    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3482    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3483    writeln!(code)?;
3484    writeln!(
3485        code,
3486        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3487    )?;
3488    writeln!(code)?;
3489    writeln!(
3490        code,
3491        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3492    )?;
3493    writeln!(
3494        code,
3495        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3496    )?;
3497    writeln!(code, "            let byte_len = n * f32_size;")?;
3498    writeln!(code, "            let cur = cursor.get();")?;
3499    writeln!(
3500        code,
3501        "            let data = &weights[cur..cur + byte_len];"
3502    )?;
3503    writeln!(code, "            cursor.set(cur + byte_len);")?;
3504    writeln!(code, "            device.new_buffer_with_data(")?;
3505    writeln!(code, "                data.as_ptr() as *const _,")?;
3506    writeln!(code, "                byte_len as u64,")?;
3507    writeln!(
3508        code,
3509        "                MTLResourceOptions::StorageModeShared,"
3510    )?;
3511    writeln!(code, "            )")?;
3512    writeln!(code, "        }};")?;
3513    writeln!(code)?;
3514
3515    if is_q8 {
3516        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3517        // We load them directly into Metal buffers without dequantizing,
3518        // and use the matmul_vec_q8 shader that operates on quantized data.
3519        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
3520        writeln!(
3521            code,
3522            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3523        )?;
3524        writeln!(
3525            code,
3526            "        // as raw bytes into a Metal buffer (no dequantization)."
3527        )?;
3528        writeln!(
3529            code,
3530            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3531        )?;
3532        writeln!(
3533            code,
3534            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3535        )?;
3536        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3537        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3538        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3539        writeln!(code, "            let cur = cursor.get();")?;
3540        writeln!(
3541            code,
3542            "            let data = &weights[cur..cur + total_raw];"
3543        )?;
3544        writeln!(code, "            cursor.set(cur + total_raw);")?;
3545        writeln!(code, "            device.new_buffer_with_data(")?;
3546        writeln!(code, "                data.as_ptr() as *const _,")?;
3547        writeln!(code, "                total_raw as u64,")?;
3548        writeln!(
3549            code,
3550            "                MTLResourceOptions::StorageModeShared,"
3551        )?;
3552        writeln!(code, "            )")?;
3553        writeln!(code, "        }};")?;
3554        writeln!(code)?;
3555        writeln!(
3556            code,
3557            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3558        )?;
3559        writeln!(
3560            code,
3561            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3562        )?;
3563        writeln!(
3564            code,
3565            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3566        )?;
3567        writeln!(
3568            code,
3569            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3570        )?;
3571        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3572        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3573        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3574        writeln!(code, "            let cur = cursor.get();")?;
3575        writeln!(
3576            code,
3577            "            let data = &weights[cur..cur + total_raw];"
3578        )?;
3579        writeln!(code, "            cursor.set(cur + total_raw);")?;
3580        writeln!(code, "            device.new_buffer_with_data(")?;
3581        writeln!(code, "                data.as_ptr() as *const _,")?;
3582        writeln!(code, "                total_raw as u64,")?;
3583        writeln!(
3584            code,
3585            "                MTLResourceOptions::StorageModeShared,"
3586        )?;
3587        writeln!(code, "            )")?;
3588        writeln!(code, "        }};")?;
3589        writeln!(code)?;
3590    }
3591
3592    if is_q4 {
3593        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3594        // We load them directly into Metal buffers without dequantizing,
3595        // and use the matmul_vec_q4 shader that operates on quantized data.
3596        // This quarters GPU memory usage vs f32 dequantization.
3597        writeln!(
3598            code,
3599            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3600        )?;
3601        writeln!(
3602            code,
3603            "        // as raw bytes into a Metal buffer (no dequantization)."
3604        )?;
3605        writeln!(
3606            code,
3607            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3608        )?;
3609        writeln!(
3610            code,
3611            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3612        )?;
3613        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3614        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3615        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3616        writeln!(code, "            let cur = cursor.get();")?;
3617        writeln!(
3618            code,
3619            "            let data = &weights[cur..cur + total_raw];"
3620        )?;
3621        writeln!(code, "            cursor.set(cur + total_raw);")?;
3622        writeln!(code, "            device.new_buffer_with_data(")?;
3623        writeln!(code, "                data.as_ptr() as *const _,")?;
3624        writeln!(code, "                total_raw as u64,")?;
3625        writeln!(
3626            code,
3627            "                MTLResourceOptions::StorageModeShared,"
3628        )?;
3629        writeln!(code, "            )")?;
3630        writeln!(code, "        }};")?;
3631        writeln!(code)?;
3632        writeln!(
3633            code,
3634            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3635        )?;
3636        writeln!(
3637            code,
3638            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3639        )?;
3640        writeln!(
3641            code,
3642            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3643        )?;
3644        writeln!(
3645            code,
3646            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3647        )?;
3648        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3649        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3650        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3651        writeln!(code, "            let cur = cursor.get();")?;
3652        writeln!(
3653            code,
3654            "            let data = &weights[cur..cur + total_raw];"
3655        )?;
3656        writeln!(code, "            cursor.set(cur + total_raw);")?;
3657        writeln!(code, "            device.new_buffer_with_data(")?;
3658        writeln!(code, "                data.as_ptr() as *const _,")?;
3659        writeln!(code, "                total_raw as u64,")?;
3660        writeln!(
3661            code,
3662            "                MTLResourceOptions::StorageModeShared,"
3663        )?;
3664        writeln!(code, "            )")?;
3665        writeln!(code, "        }};")?;
3666        writeln!(code)?;
3667    }
3668
3669    writeln!(
3670        code,
3671        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3672    )?;
3673    writeln!(code)?;
3674
3675    // Per-layer weights
3676    writeln!(
3677        code,
3678        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3679    )?;
3680    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3681
3682    // attn_norm is always f32
3683    writeln!(
3684        code,
3685        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3686    )?;
3687
3688    let qkv_rows = hidden + 2 * kv_dim;
3689    if is_q8 {
3690        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3691        writeln!(
3692            code,
3693            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3694        )?;
3695        if config.qkv_bias {
3696            writeln!(
3697                code,
3698                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3699            )?;
3700            writeln!(
3701                code,
3702                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3703            )?;
3704        }
3705        writeln!(
3706            code,
3707            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3708        )?;
3709    } else if is_q4 {
3710        writeln!(
3711            code,
3712            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3713        )?;
3714        if config.qkv_bias {
3715            writeln!(
3716                code,
3717                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3718            )?;
3719        }
3720        writeln!(
3721            code,
3722            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3723        )?;
3724    } else {
3725        writeln!(
3726            code,
3727            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3728        )?;
3729        if config.qkv_bias {
3730            writeln!(
3731                code,
3732                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3733            )?;
3734        }
3735        writeln!(
3736            code,
3737            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3738        )?;
3739    }
3740
3741    // ffn_norm is always f32
3742    writeln!(
3743        code,
3744        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3745    )?;
3746
3747    let gate_up_rows = 2 * intermediate;
3748    if is_q8 {
3749        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3750        writeln!(
3751            code,
3752            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3753        )?;
3754        writeln!(
3755            code,
3756            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3757        )?;
3758    } else if is_q4 {
3759        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3760        writeln!(
3761            code,
3762            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3763        )?;
3764        writeln!(
3765            code,
3766            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3767        )?;
3768    } else {
3769        // Fused gate+up weight: read both as a single contiguous f32 buffer
3770        writeln!(
3771            code,
3772            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3773        )?;
3774        writeln!(
3775            code,
3776            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3777        )?;
3778    }
3779
3780    writeln!(code, "            layers.push(LayerBuffers {{")?;
3781    writeln!(code, "                attn_norm,")?;
3782    writeln!(code, "                qkv_weight,")?;
3783    if config.qkv_bias {
3784        writeln!(code, "                qkv_bias,")?;
3785    }
3786    writeln!(code, "                o_weight,")?;
3787    writeln!(code, "                ffn_norm,")?;
3788    writeln!(code, "                gate_up_weight,")?;
3789    writeln!(code, "                down_weight,")?;
3790    writeln!(code, "            }});")?;
3791    writeln!(code, "        }}")?;
3792    writeln!(code)?;
3793
3794    // final_norm is always f32
3795    writeln!(
3796        code,
3797        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3798    )?;
3799    writeln!(code)?;
3800
3801    // lm_head
3802    if is_q8 {
3803        writeln!(
3804            code,
3805            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3806        )?;
3807    } else if is_q4 {
3808        writeln!(
3809            code,
3810            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3811        )?;
3812    } else {
3813        writeln!(
3814            code,
3815            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3816        )?;
3817    }
3818    writeln!(code)?;
3819
3820    // Working buffers
3821    let hidden_bytes = hidden * 4;
3822    let _kv_dim_bytes = kv_dim * 4;
3823    let intermediate_bytes = intermediate * 4;
3824    let vocab_bytes = vocab * 4;
3825    // KV cache is stored as f16 (2 bytes/element) to halve attention memory
3826    // bandwidth in long-context decode.
3827    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3828
3829    writeln!(
3830        code,
3831        "        // Allocate working buffers (shared memory for zero-copy)"
3832    )?;
3833    writeln!(
3834        code,
3835        "        let opts = MTLResourceOptions::StorageModeShared;"
3836    )?;
3837    writeln!(
3838        code,
3839        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3840    )?;
3841    writeln!(
3842        code,
3843        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3844    )?;
3845    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3846    writeln!(
3847        code,
3848        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3849    )?;
3850    writeln!(
3851        code,
3852        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3853    )?;
3854    writeln!(
3855        code,
3856        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3857    )?;
3858    writeln!(
3859        code,
3860        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3861    )?;
3862    writeln!(
3863        code,
3864        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3865    )?;
3866    let gate_up_buf_bytes = 2 * intermediate * 4;
3867    writeln!(
3868        code,
3869        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3870    )?;
3871    writeln!(
3872        code,
3873        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3874    )?;
3875    writeln!(
3876        code,
3877        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3878    )?;
3879    writeln!(
3880        code,
3881        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3882    )?;
3883    writeln!(
3884        code,
3885        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3886    )?;
3887    writeln!(
3888        code,
3889        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3890    )?;
3891    writeln!(code)?;
3892
3893    // Batch prefill working buffers
3894    let batch_hidden_bytes = hidden * 4; // per-token
3895    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3896    let batch_gate_up_bytes = 2 * intermediate * 4;
3897    let batch_intermediate_bytes = intermediate * 4;
3898    writeln!(
3899        code,
3900        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3901    )?;
3902    writeln!(
3903        code,
3904        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3905    )?;
3906    writeln!(
3907        code,
3908        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3909    )?;
3910    writeln!(
3911        code,
3912        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3913    )?;
3914    writeln!(
3915        code,
3916        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3917    )?;
3918    writeln!(
3919        code,
3920        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3921    )?;
3922    writeln!(
3923        code,
3924        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3925    )?;
3926    writeln!(
3927        code,
3928        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3929    )?;
3930    writeln!(
3931        code,
3932        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3933    )?;
3934    writeln!(
3935        code,
3936        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3937    )?;
3938    writeln!(
3939        code,
3940        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3941    )?;
3942    writeln!(code)?;
3943
3944    // KV cache buffers
3945    writeln!(code, "        // KV cache buffers (per-layer)")?;
3946    writeln!(
3947        code,
3948        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3949    )?;
3950    writeln!(
3951        code,
3952        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3953    )?;
3954    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3955    writeln!(
3956        code,
3957        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3958    )?;
3959    writeln!(
3960        code,
3961        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3962    )?;
3963    writeln!(code, "        }}")?;
3964    writeln!(code)?;
3965
3966    writeln!(code, "        Self {{")?;
3967    writeln!(code, "            device,")?;
3968    writeln!(code, "            queue,")?;
3969    writeln!(code, "            matmul_pipeline,")?;
3970    writeln!(code, "            matmul_q8_pipeline,")?;
3971    writeln!(code, "            matmul_q4_pipeline,")?;
3972    writeln!(code, "            rms_norm_pipeline,")?;
3973    writeln!(code, "            rope_pipeline,")?;
3974    writeln!(code, "            softmax_pipeline,")?;
3975    writeln!(code, "            silu_mul_pipeline,")?;
3976    writeln!(code, "            silu_mul_fused_pipeline,")?;
3977    writeln!(code, "            add_pipeline,")?;
3978    writeln!(code, "            attention_pipeline,")?;
3979    writeln!(code, "            add_inplace_pipeline,")?;
3980    writeln!(code, "            copy_pipeline,")?;
3981    writeln!(code, "            copy_offset_pipeline,")?;
3982    writeln!(code, "            copy_f32_to_f16_offset_pipeline,")?;
3983    writeln!(code, "            matmul_batch_pipeline,")?;
3984    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3985    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3986    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3987    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3988    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3989    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3990    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3991    if config.qkv_bias {
3992        writeln!(code, "            add_bias_batch_pipeline,")?;
3993    }
3994    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3995    writeln!(code, "            rms_norm_batch_pipeline,")?;
3996    writeln!(code, "            rope_batch_pipeline,")?;
3997    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3998    writeln!(code, "            add_inplace_batch_pipeline,")?;
3999    writeln!(code, "            copy_embedding_batch_pipeline,")?;
4000    writeln!(code, "            scale_buffer_pipeline,")?;
4001    writeln!(code, "            attention_batch_pipeline,")?;
4002    writeln!(code, "            attention_flash_batch_pipeline,")?;
4003    writeln!(code, "            attention_mma_flash_batch_pipeline,")?;
4004    writeln!(code, "            copy_kv_batch_pipeline,")?;
4005    writeln!(code, "            rope_qk_batch_pipeline,")?;
4006    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
4007    writeln!(code, "            embed_buf,")?;
4008    writeln!(code, "            layers,")?;
4009    writeln!(code, "            norm_buf,")?;
4010    writeln!(code, "            lm_head_buf,")?;
4011    writeln!(code, "            hidden_buf,")?;
4012    writeln!(code, "            residual_buf,")?;
4013    writeln!(code, "            normed_buf,")?;
4014    writeln!(code, "            qkv_buf,")?;
4015    writeln!(code, "            attn_out_buf,")?;
4016    writeln!(code, "            attn_proj_buf,")?;
4017    writeln!(code, "            gate_up_buf,")?;
4018    writeln!(code, "            ffn_hidden_buf,")?;
4019    writeln!(code, "            ffn_out_buf,")?;
4020    writeln!(code, "            add_tmp_buf,")?;
4021    writeln!(code, "            logits_buf,")?;
4022    writeln!(code, "            batch_hidden_buf,")?;
4023    writeln!(code, "            batch_residual_buf,")?;
4024    writeln!(code, "            batch_qkv_buf,")?;
4025    writeln!(code, "            batch_attn_out_buf,")?;
4026    writeln!(code, "            batch_attn_proj_buf,")?;
4027    writeln!(code, "            batch_gate_up_buf,")?;
4028    writeln!(code, "            batch_ffn_hidden_buf,")?;
4029    writeln!(code, "            batch_ffn_out_buf,")?;
4030    writeln!(code, "            batch_tokens_buf,")?;
4031    writeln!(code, "            batch_positions_buf,")?;
4032    writeln!(code, "            k_cache,")?;
4033    writeln!(code, "            v_cache,")?;
4034    writeln!(code, "            pos: 0,")?;
4035    writeln!(code, "            prev_cmd: None,")?;
4036    writeln!(code, "        }}")?;
4037    writeln!(code, "    }}")?;
4038    writeln!(code)?;
4039
4040    // ── forward() ──
4041    writeln!(
4042        code,
4043        "    /// Run the forward pass for a single token at the current position."
4044    )?;
4045    writeln!(code, "    ///")?;
4046    writeln!(
4047        code,
4048        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
4049    )?;
4050    writeln!(code, "    ///")?;
4051    writeln!(
4052        code,
4053        "    /// All GPU operations are encoded into a single command buffer and"
4054    )?;
4055    writeln!(
4056        code,
4057        "    /// committed once at the end, avoiding per-operation synchronization."
4058    )?;
4059    writeln!(
4060        code,
4061        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
4062    )?;
4063    writeln!(
4064        code,
4065        "        // Wait for any pending prefill command buffer"
4066    )?;
4067    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4068    writeln!(code, "            prev.wait_until_completed();")?;
4069    writeln!(code, "        }}")?;
4070    writeln!(code)?;
4071    writeln!(code, "        let pos = self.pos;")?;
4072    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4073    writeln!(code)?;
4074
4075    // Single compute encoder for the entire forward pass — no blit encoder
4076    // transitions. Copy operations use compute copy kernels instead of blits.
4077    let matmul_fn = if is_q8 {
4078        "dispatch_matmul_q8"
4079    } else if is_q4 {
4080        "dispatch_matmul_q4"
4081    } else {
4082        "dispatch_matmul"
4083    };
4084
4085    writeln!(
4086        code,
4087        "        // Single compute encoder for the entire forward pass (no blit transitions)"
4088    )?;
4089    writeln!(code, "        {{")?;
4090    writeln!(
4091        code,
4092        "            let enc = cmd.new_compute_command_encoder();"
4093    )?;
4094    writeln!(code)?;
4095
4096    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
4097    writeln!(
4098        code,
4099        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
4100    )?;
4101    writeln!(
4102        code,
4103        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
4104    )?;
4105    writeln!(
4106        code,
4107        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
4108    )?;
4109    writeln!(
4110        code,
4111        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4112        hidden * 4,
4113    )?;
4114    writeln!(code, "            unsafe {{")?;
4115    writeln!(
4116        code,
4117        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
4118    )?;
4119    writeln!(
4120        code,
4121        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4122    )?;
4123    writeln!(
4124        code,
4125        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
4126    )?;
4127    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
4128    writeln!(
4129        code,
4130        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4131    )?;
4132    writeln!(code, "                    hidden_ptr,")?;
4133    writeln!(code, "                    HIDDEN_SIZE,")?;
4134    writeln!(code, "                );")?;
4135    if matches!(config.hidden_activation, HiddenActivation::GeluApprox) {
4136        // Gemma-1 multiplies the embedding by sqrt(hidden_size) after lookup.
4137        // Applied here rather than at weight load time so tied embeddings
4138        // (lm_head shares the `embed_tokens` buffer) keep the logit projection
4139        // unscaled.
4140        writeln!(
4141            code,
4142            "                const GEMMA_EMBED_SCALE: f32 = {scale:e}_f32;",
4143            scale = (hidden as f32).sqrt(),
4144        )?;
4145        writeln!(code, "                for i in 0..HIDDEN_SIZE {{")?;
4146        writeln!(
4147            code,
4148            "                    *hidden_ptr.add(i) *= GEMMA_EMBED_SCALE;"
4149        )?;
4150        writeln!(code, "                }}")?;
4151    }
4152    writeln!(
4153        code,
4154        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4155    )?;
4156    writeln!(code, "            }}")?;
4157    writeln!(code)?;
4158
4159    // 2. Transformer layers
4160    writeln!(code, "            // 2. Transformer layers")?;
4161    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4162    writeln!(code)?;
4163    let q_byte_offset = 0usize;
4164    let k_float_offset = hidden;
4165    let v_float_offset = hidden + kv_dim;
4166
4167    writeln!(
4168        code,
4169        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
4170    )?;
4171    writeln!(
4172        code,
4173        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4174    )?;
4175    writeln!(
4176        code,
4177        "                // Fused Q+K+V matmul: single dispatch for all three projections"
4178    )?;
4179    writeln!(
4180        code,
4181        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4182    )?;
4183    if config.qkv_bias {
4184        writeln!(
4185            code,
4186            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4187        )?;
4188        writeln!(
4189            code,
4190            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4191        )?;
4192    }
4193    writeln!(
4194        code,
4195        "                // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
4196    )?;
4197    writeln!(
4198        code,
4199        "                self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4200    )?;
4201    writeln!(code)?;
4202    writeln!(
4203        code,
4204        "                // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
4205    )?;
4206    writeln!(
4207        code,
4208        "                self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4209    )?;
4210    writeln!(code)?;
4211    writeln!(
4212        code,
4213        "                // Attention using Q from qkv_buf (offset 0)"
4214    )?;
4215    writeln!(
4216        code,
4217        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4218    )?;
4219    writeln!(
4220        code,
4221        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4222    )?;
4223    writeln!(
4224        code,
4225        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4226    )?;
4227    writeln!(
4228        code,
4229        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4230    )?;
4231    writeln!(
4232        code,
4233        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4234    )?;
4235    writeln!(
4236        code,
4237        "                // Fused gate+up matmul: single dispatch for both projections"
4238    )?;
4239    writeln!(
4240        code,
4241        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4242    )?;
4243    writeln!(
4244        code,
4245        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4246    )?;
4247    writeln!(
4248        code,
4249        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4250    )?;
4251    writeln!(
4252        code,
4253        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4254    )?;
4255    writeln!(code, "            }}")?;
4256    writeln!(code)?;
4257
4258    // 3. Final RMS norm + logits
4259    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
4260    writeln!(
4261        code,
4262        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4263    )?;
4264    writeln!(
4265        code,
4266        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4267    )?;
4268    writeln!(code)?;
4269    writeln!(code, "            enc.end_encoding();")?;
4270    writeln!(code, "        }}")?;
4271    writeln!(code)?;
4272
4273    // 5. Single commit + wait, then read back logits
4274    writeln!(
4275        code,
4276        "        // 5. Commit all GPU work and wait for completion"
4277    )?;
4278    writeln!(code, "        cmd.commit();")?;
4279    writeln!(code, "        cmd.wait_until_completed();")?;
4280    writeln!(code)?;
4281    writeln!(code, "        // 6. Read back logits from GPU")?;
4282    writeln!(code, "        let logits = unsafe {{")?;
4283    writeln!(
4284        code,
4285        "            let ptr = self.logits_buf.contents() as *const f32;"
4286    )?;
4287    writeln!(
4288        code,
4289        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4290    )?;
4291    writeln!(code, "        }};")?;
4292    writeln!(code)?;
4293    writeln!(code, "        self.pos += 1;")?;
4294    writeln!(code, "        logits")?;
4295    writeln!(code, "    }}")?;
4296    writeln!(code)?;
4297
4298    // ── forward_profile: instrumented forward with per-operation timing ──
4299    writeln!(
4300        code,
4301        "    /// Profiling forward pass that prints per-stage GPU timing."
4302    )?;
4303    writeln!(code, "    ///")?;
4304    writeln!(
4305        code,
4306        "    /// Each stage is committed and waited on separately so that GPU timestamps"
4307    )?;
4308    writeln!(
4309        code,
4310        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
4311    )?;
4312    writeln!(
4313        code,
4314        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
4315    )?;
4316    writeln!(
4317        code,
4318        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4319    )?;
4320    writeln!(code, "        use std::time::Instant;")?;
4321    writeln!(code)?;
4322    writeln!(
4323        code,
4324        "        // Wait for any pending prefill command buffer"
4325    )?;
4326    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4327    writeln!(code, "            prev.wait_until_completed();")?;
4328    writeln!(code, "        }}")?;
4329    writeln!(code)?;
4330    writeln!(code, "        let pos = self.pos;")?;
4331    writeln!(code)?;
4332
4333    // Stage: embedding (CPU, no GPU)
4334    writeln!(
4335        code,
4336        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
4337    )?;
4338    writeln!(code, "        let t_embed = Instant::now();")?;
4339    writeln!(code, "        unsafe {{")?;
4340    writeln!(
4341        code,
4342        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
4343    )?;
4344    writeln!(
4345        code,
4346        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4347    )?;
4348    writeln!(
4349        code,
4350        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
4351    )?;
4352    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
4353    writeln!(
4354        code,
4355        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4356    )?;
4357    writeln!(code, "                hidden_ptr,")?;
4358    writeln!(code, "                HIDDEN_SIZE,")?;
4359    writeln!(code, "            );")?;
4360    if matches!(config.hidden_activation, HiddenActivation::GeluApprox) {
4361        writeln!(
4362            code,
4363            "            const GEMMA_EMBED_SCALE: f32 = {scale:e}_f32;",
4364            scale = (hidden as f32).sqrt(),
4365        )?;
4366        writeln!(code, "            for i in 0..HIDDEN_SIZE {{")?;
4367        writeln!(
4368            code,
4369            "                *hidden_ptr.add(i) *= GEMMA_EMBED_SCALE;"
4370        )?;
4371        writeln!(code, "            }}")?;
4372    }
4373    writeln!(
4374        code,
4375        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4376    )?;
4377    writeln!(code, "        }}")?;
4378    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
4379    writeln!(code)?;
4380
4381    // Stage: Transformer layers (all together on GPU)
4382    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
4383    writeln!(code, "        let t_layers = Instant::now();")?;
4384    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4385    writeln!(code, "        {{")?;
4386    writeln!(
4387        code,
4388        "            let enc = cmd.new_compute_command_encoder();"
4389    )?;
4390    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4391    writeln!(
4392        code,
4393        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4394    )?;
4395    writeln!(
4396        code,
4397        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4398    )?;
4399    if config.qkv_bias {
4400        writeln!(
4401            code,
4402            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4403        )?;
4404    }
4405    writeln!(
4406        code,
4407        "                self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4408    )?;
4409    writeln!(
4410        code,
4411        "                self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4412    )?;
4413    writeln!(
4414        code,
4415        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4416    )?;
4417    writeln!(
4418        code,
4419        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4420    )?;
4421    writeln!(
4422        code,
4423        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4424    )?;
4425    writeln!(
4426        code,
4427        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4428    )?;
4429    writeln!(
4430        code,
4431        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4432    )?;
4433    writeln!(
4434        code,
4435        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4436    )?;
4437    writeln!(
4438        code,
4439        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4440    )?;
4441    writeln!(
4442        code,
4443        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4444    )?;
4445    writeln!(code, "            }}")?;
4446    writeln!(code, "            enc.end_encoding();")?;
4447    writeln!(code, "        }}")?;
4448    writeln!(code, "        cmd.commit();")?;
4449    writeln!(code, "        cmd.wait_until_completed();")?;
4450    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4451    writeln!(code)?;
4452
4453    // Stage: Final norm + logits
4454    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4455    writeln!(code, "        let t_logits = Instant::now();")?;
4456    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4457    writeln!(code, "        {{")?;
4458    writeln!(
4459        code,
4460        "            let enc = cmd2.new_compute_command_encoder();"
4461    )?;
4462    writeln!(
4463        code,
4464        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4465    )?;
4466    writeln!(
4467        code,
4468        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4469    )?;
4470    writeln!(code, "            enc.end_encoding();")?;
4471    writeln!(code, "        }}")?;
4472    writeln!(code, "        cmd2.commit();")?;
4473    writeln!(code, "        cmd2.wait_until_completed();")?;
4474    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4475    writeln!(code)?;
4476
4477    // Print profile results
4478    writeln!(
4479        code,
4480        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4481    )?;
4482    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4483    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4484    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4485    writeln!(
4486        code,
4487        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4488    )?;
4489    writeln!(code)?;
4490
4491    // Read back logits
4492    writeln!(code, "        let logits = unsafe {{")?;
4493    writeln!(
4494        code,
4495        "            let ptr = self.logits_buf.contents() as *const f32;"
4496    )?;
4497    writeln!(
4498        code,
4499        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4500    )?;
4501    writeln!(code, "        }};")?;
4502    writeln!(code)?;
4503    writeln!(code, "        self.pos += 1;")?;
4504    writeln!(code, "        logits")?;
4505    writeln!(code, "    }}")?;
4506    writeln!(code)?;
4507
4508    // ── forward_prefill: single-token async forward (backward compat) ──
4509    writeln!(
4510        code,
4511        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4512    )?;
4513    writeln!(code, "    ///")?;
4514    writeln!(
4515        code,
4516        "    /// Commits the command buffer without waiting, enabling double-buffered"
4517    )?;
4518    writeln!(
4519        code,
4520        "    /// execution: GPU processes token N while CPU encodes token N+1."
4521    )?;
4522    writeln!(
4523        code,
4524        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4525    )?;
4526    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4527    writeln!(code, "    }}")?;
4528    writeln!(code)?;
4529
4530    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4531    // Batched matmuls for QKV/O/FFN projections and batched causal attention (one dispatch for all M tokens).
4532    let batch_matmul_fn = if is_q8 {
4533        "dispatch_matmul_q8_batch"
4534    } else if is_q4 {
4535        "dispatch_matmul_q4_batch"
4536    } else {
4537        "dispatch_matmul_batch"
4538    };
4539
4540    writeln!(
4541        code,
4542        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4543    )?;
4544    writeln!(code, "    ///")?;
4545    writeln!(
4546        code,
4547        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4548    )?;
4549    writeln!(
4550        code,
4551        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4552    )?;
4553    writeln!(
4554        code,
4555        "    /// This provides significant speedup during prompt prefill."
4556    )?;
4557    writeln!(
4558        code,
4559        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4560    )?;
4561    writeln!(code, "        if tokens.is_empty() {{ return; }}")?;
4562    writeln!(
4563        code,
4564        "        // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4565    )?;
4566    writeln!(
4567        code,
4568        "        // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4569    )?;
4570    writeln!(
4571        code,
4572        "        // longer than that must be processed iteratively.  The KV cache"
4573    )?;
4574    writeln!(code, "        // carries state across chunks via self.pos.")?;
4575    writeln!(
4576        code,
4577        "        for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4578    )?;
4579    writeln!(code, "        let m = chunk.len();")?;
4580    writeln!(code, "        if m == 0 {{ continue; }}")?;
4581    writeln!(code, "        let start_pos = self.pos;")?;
4582    writeln!(code)?;
4583    writeln!(code, "        // Wait for any pending command buffer")?;
4584    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4585    writeln!(code, "            prev.wait_until_completed();")?;
4586    writeln!(code, "        }}")?;
4587    writeln!(code)?;
4588
4589    // Upload token IDs and positions to GPU
4590    writeln!(
4591        code,
4592        "        // Upload token IDs and positions to GPU buffers"
4593    )?;
4594    writeln!(code, "        unsafe {{")?;
4595    writeln!(
4596        code,
4597        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4598    )?;
4599    writeln!(
4600        code,
4601        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4602    )?;
4603    writeln!(code, "            for i in 0..m {{")?;
4604    writeln!(code, "                *tok_ptr.add(i) = chunk[i];")?;
4605    writeln!(
4606        code,
4607        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4608    )?;
4609    writeln!(code, "            }}")?;
4610    writeln!(code, "        }}")?;
4611    writeln!(code)?;
4612
4613    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4614    writeln!(code, "        {{")?;
4615    writeln!(
4616        code,
4617        "            let enc = cmd.new_compute_command_encoder();"
4618    )?;
4619    writeln!(code)?;
4620
4621    // 1. Batch embedding lookup
4622    writeln!(
4623        code,
4624        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4625    )?;
4626    writeln!(
4627        code,
4628        "            self.dispatch_copy_embedding_batch(&enc, m);"
4629    )?;
4630    if matches!(config.hidden_activation, HiddenActivation::GeluApprox) {
4631        // Gemma-1: scale embeddings by sqrt(hidden_size) after lookup.
4632        writeln!(
4633            code,
4634            "            // Gemma-1 embedding scale by sqrt(hidden_size)"
4635        )?;
4636        writeln!(
4637            code,
4638            "            self.dispatch_scale_buffer(&enc, &self.batch_hidden_buf, {scale:e}_f32, m * {hidden});",
4639            scale = (hidden as f32).sqrt(),
4640        )?;
4641    }
4642    // Copy batch_hidden -> batch_residual
4643    writeln!(
4644        code,
4645        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4646    )?;
4647    writeln!(code)?;
4648
4649    // 2. Transformer layers
4650    writeln!(code, "            // 2. Transformer layers")?;
4651    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4652    writeln!(code)?;
4653
4654    // Batch RMS norm: residual -> hidden (batched)
4655    writeln!(
4656        code,
4657        "                // Batch RMS norm: batch_residual -> batch_hidden"
4658    )?;
4659    writeln!(
4660        code,
4661        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4662    )?;
4663
4664    // Batch QKV matmul
4665    writeln!(
4666        code,
4667        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4668    )?;
4669    writeln!(
4670        code,
4671        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4672    )?;
4673    if config.qkv_bias {
4674        writeln!(
4675            code,
4676            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4677        )?;
4678        writeln!(
4679            code,
4680            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4681        )?;
4682    }
4683    writeln!(code)?;
4684
4685    // Fused RoPE on Q+K portions in a single dispatch
4686    let k_float_offset = hidden;
4687    writeln!(
4688        code,
4689        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4690    )?;
4691    writeln!(
4692        code,
4693        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4694    )?;
4695    writeln!(code)?;
4696
4697    // Fused KV cache update: copy both K and V in a single dispatch
4698    let v_float_offset = hidden + kv_dim;
4699    writeln!(
4700        code,
4701        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4702    )?;
4703    writeln!(
4704        code,
4705        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4706    )?;
4707    writeln!(code)?;
4708
4709    // Batched causal attention: ONE dispatch for all M tokens
4710    writeln!(
4711        code,
4712        "                // Batched causal attention: one dispatch for all M tokens"
4713    )?;
4714    writeln!(
4715        code,
4716        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4717    )?;
4718    writeln!(code)?;
4719
4720    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4721    writeln!(code, "                // Batched O projection")?;
4722    writeln!(
4723        code,
4724        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4725    )?;
4726    writeln!(code)?;
4727
4728    // Batch add: residual += attn_proj for all tokens
4729    writeln!(
4730        code,
4731        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4732    )?;
4733    writeln!(code)?;
4734
4735    // Batch FFN
4736    writeln!(
4737        code,
4738        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4739    )?;
4740    writeln!(
4741        code,
4742        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4743    )?;
4744    writeln!(
4745        code,
4746        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4747    )?;
4748    writeln!(
4749        code,
4750        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4751    )?;
4752    writeln!(
4753        code,
4754        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4755    )?;
4756    writeln!(
4757        code,
4758        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4759    )?;
4760    writeln!(code, "            }}")?;
4761    writeln!(code)?;
4762
4763    // Copy last token's residual to single-token residual_buf for next forward() call
4764    writeln!(
4765        code,
4766        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4767    )?;
4768    writeln!(
4769        code,
4770        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4771    )?;
4772    writeln!(code)?;
4773    writeln!(code, "            enc.end_encoding();")?;
4774    writeln!(code, "        }}")?;
4775    writeln!(code)?;
4776
4777    writeln!(code, "        cmd.commit();")?;
4778    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4779    writeln!(code, "        self.pos += m;")?;
4780    writeln!(code, "        }}  // end for chunk")?;
4781    writeln!(code, "    }}")?;
4782    writeln!(code)?;
4783
4784    // ── reset() — rewind KV cache position for new inference requests ──
4785    writeln!(
4786        code,
4787        "    /// Reset the model state for a new inference request."
4788    )?;
4789    writeln!(code, "    pub fn reset(&mut self) {{")?;
4790    writeln!(code, "        self.pos = 0;")?;
4791    writeln!(code, "        self.prev_cmd = None;")?;
4792    writeln!(code, "    }}")?;
4793    writeln!(code)?;
4794
4795    // ── Private dispatch helpers (all take a shared compute encoder) ──
4796    writeln!(
4797        code,
4798        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4799    )?;
4800    writeln!(
4801        code,
4802        "    // These methods set pipeline state + buffers + dispatch on an existing"
4803    )?;
4804    writeln!(
4805        code,
4806        "    // encoder, avoiding per-operation encoder creation overhead."
4807    )?;
4808    writeln!(code)?;
4809
4810    // dispatch_rms_norm
4811    writeln!(
4812        code,
4813        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4814    )?;
4815    writeln!(
4816        code,
4817        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4818    )?;
4819    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4820    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4821    writeln!(
4822        code,
4823        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4824    )?;
4825    writeln!(
4826        code,
4827        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4828    )?;
4829    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4830    writeln!(
4831        code,
4832        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4833    )?;
4834    writeln!(
4835        code,
4836        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4837    )?;
4838    writeln!(
4839        code,
4840        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4841    )?;
4842    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4843    writeln!(
4844        code,
4845        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4846    )?;
4847    writeln!(
4848        code,
4849        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4850    )?;
4851    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4852    writeln!(code, "    }}")?;
4853    writeln!(code)?;
4854
4855    // dispatch_matmul
4856    writeln!(
4857        code,
4858        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4859    )?;
4860    writeln!(
4861        code,
4862        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4863    )?;
4864    writeln!(code, "        let r: u32 = rows as u32;")?;
4865    writeln!(code, "        let c: u32 = cols as u32;")?;
4866    writeln!(
4867        code,
4868        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4869    )?;
4870    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4871    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4872    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4873    writeln!(
4874        code,
4875        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4876    )?;
4877    writeln!(
4878        code,
4879        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4880    )?;
4881    writeln!(
4882        code,
4883        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4884    )?;
4885    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4886    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4887    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4888    writeln!(
4889        code,
4890        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4891    )?;
4892    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4893    writeln!(code, "    }}")?;
4894    writeln!(code)?;
4895
4896    // dispatch_matmul_q8
4897    writeln!(
4898        code,
4899        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4900    )?;
4901    writeln!(
4902        code,
4903        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4904    )?;
4905    writeln!(
4906        code,
4907        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4908    )?;
4909    writeln!(code, "        let r: u32 = rows as u32;")?;
4910    writeln!(code, "        let c: u32 = cols as u32;")?;
4911    // matmul_q8 caches the input vector in an 8192-float threadgroup tile;
4912    // for cols > 8192 (e.g. Gemma-2B's 16384-wide down-proj) the tile
4913    // overflows. Route large-col calls through the gemm kernel with M=1 —
4914    // it reads inputs directly from device memory without caching.
4915    writeln!(code, "        if cols > 8192 {{")?;
4916    writeln!(code, "            let nt: u32 = 1u32;")?;
4917    writeln!(
4918        code,
4919        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
4920    )?;
4921    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
4922    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
4923    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
4924    writeln!(
4925        code,
4926        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4927    )?;
4928    writeln!(
4929        code,
4930        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4931    )?;
4932    writeln!(
4933        code,
4934        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4935    )?;
4936    writeln!(code, "            let row_tgs = ((rows + 31) / 32) as u64;")?;
4937    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
4938    writeln!(
4939        code,
4940        "            let grid_size = MTLSize::new(row_tgs, 1, 1);"
4941    )?;
4942    writeln!(
4943        code,
4944        "            enc.dispatch_thread_groups(grid_size, tg_size);"
4945    )?;
4946    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4947    writeln!(code, "            return;")?;
4948    writeln!(code, "        }}")?;
4949    writeln!(
4950        code,
4951        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4952    )?;
4953    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4954    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4955    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4956    writeln!(
4957        code,
4958        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4959    )?;
4960    writeln!(
4961        code,
4962        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4963    )?;
4964    writeln!(
4965        code,
4966        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4967    )?;
4968    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4969    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4970    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4971    writeln!(
4972        code,
4973        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4974    )?;
4975    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4976    writeln!(code, "    }}")?;
4977    writeln!(code)?;
4978
4979    // dispatch_matmul_q4
4980    writeln!(
4981        code,
4982        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4983    )?;
4984    writeln!(
4985        code,
4986        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4987    )?;
4988    writeln!(
4989        code,
4990        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4991    )?;
4992    writeln!(code, "        let r: u32 = rows as u32;")?;
4993    writeln!(code, "        let c: u32 = cols as u32;")?;
4994    writeln!(
4995        code,
4996        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4997    )?;
4998    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4999    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5000    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5001    writeln!(
5002        code,
5003        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5004    )?;
5005    writeln!(
5006        code,
5007        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5008    )?;
5009    writeln!(
5010        code,
5011        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
5012    )?;
5013    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5014    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
5015    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5016    writeln!(
5017        code,
5018        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5019    )?;
5020    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5021    writeln!(code, "    }}")?;
5022    writeln!(code)?;
5023
5024    // dispatch_rope
5025    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
5026    writeln!(
5027        code,
5028        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
5029    )?;
5030    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5031    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5032    writeln!(code, "        let p: u32 = pos as u32;")?;
5033    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5034    writeln!(
5035        code,
5036        "        let total_pairs = num_heads * (head_dim / 2);"
5037    )?;
5038    writeln!(
5039        code,
5040        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
5041    )?;
5042    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
5043    writeln!(
5044        code,
5045        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5046    )?;
5047    writeln!(
5048        code,
5049        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5050    )?;
5051    writeln!(
5052        code,
5053        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5054    )?;
5055    writeln!(
5056        code,
5057        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5058    )?;
5059    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5060    writeln!(
5061        code,
5062        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5063    )?;
5064    writeln!(
5065        code,
5066        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5067    )?;
5068    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5069    writeln!(code, "    }}")?;
5070    writeln!(code)?;
5071
5072    // dispatch_rope_offset
5073    writeln!(
5074        code,
5075        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
5076    )?;
5077    writeln!(
5078        code,
5079        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
5080    )?;
5081    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5082    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5083    writeln!(code, "        let p: u32 = pos as u32;")?;
5084    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5085    writeln!(
5086        code,
5087        "        let total_pairs = num_heads * (head_dim / 2);"
5088    )?;
5089    writeln!(
5090        code,
5091        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
5092    )?;
5093    writeln!(
5094        code,
5095        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
5096    )?;
5097    writeln!(
5098        code,
5099        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5100    )?;
5101    writeln!(
5102        code,
5103        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5104    )?;
5105    writeln!(
5106        code,
5107        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5108    )?;
5109    writeln!(
5110        code,
5111        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5112    )?;
5113    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5114    writeln!(
5115        code,
5116        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5117    )?;
5118    writeln!(
5119        code,
5120        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5121    )?;
5122    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5123    writeln!(code, "    }}")?;
5124    writeln!(code)?;
5125
5126    // dispatch_attention
5127    writeln!(
5128        code,
5129        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
5130    )?;
5131    writeln!(
5132        code,
5133        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5134    )?;
5135    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
5136    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5137    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5138    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5139    writeln!(
5140        code,
5141        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
5142    )?;
5143    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
5144    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5145    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5146    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5147    writeln!(
5148        code,
5149        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5150    )?;
5151    writeln!(
5152        code,
5153        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5154    )?;
5155    writeln!(
5156        code,
5157        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5158    )?;
5159    writeln!(
5160        code,
5161        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5162    )?;
5163    writeln!(
5164        code,
5165        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
5166    )?;
5167    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5168    writeln!(
5169        code,
5170        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5171    )?;
5172    writeln!(
5173        code,
5174        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5175    )?;
5176    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5177    writeln!(code, "    }}")?;
5178    writeln!(code)?;
5179
5180    // dispatch_attention_offset
5181    writeln!(
5182        code,
5183        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
5184    )?;
5185    writeln!(
5186        code,
5187        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5188    )?;
5189    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
5190    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5191    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5192    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5193    writeln!(
5194        code,
5195        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
5196    )?;
5197    writeln!(
5198        code,
5199        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5200    )?;
5201    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5202    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5203    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5204    writeln!(
5205        code,
5206        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5207    )?;
5208    writeln!(
5209        code,
5210        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5211    )?;
5212    writeln!(
5213        code,
5214        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5215    )?;
5216    writeln!(
5217        code,
5218        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5219    )?;
5220    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5221    writeln!(
5222        code,
5223        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5224    )?;
5225    writeln!(
5226        code,
5227        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5228    )?;
5229    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5230    writeln!(code, "    }}")?;
5231    writeln!(code)?;
5232
5233    // dispatch_silu_mul
5234    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
5235    writeln!(
5236        code,
5237        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5238    )?;
5239    writeln!(code, "        let count: u32 = n as u32;")?;
5240    writeln!(
5241        code,
5242        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5243    )?;
5244    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
5245    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
5246    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5247    writeln!(
5248        code,
5249        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5250    )?;
5251    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5252    writeln!(
5253        code,
5254        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5255    )?;
5256    writeln!(
5257        code,
5258        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5259    )?;
5260    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5261    writeln!(code, "    }}")?;
5262    writeln!(code)?;
5263
5264    // dispatch_silu_mul_fused
5265    writeln!(
5266        code,
5267        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5268    )?;
5269    writeln!(
5270        code,
5271        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
5272    )?;
5273    writeln!(
5274        code,
5275        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5276    )?;
5277    writeln!(code, "        let count: u32 = n as u32;")?;
5278    writeln!(
5279        code,
5280        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5281    )?;
5282    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5283    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5284    writeln!(
5285        code,
5286        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5287    )?;
5288    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5289    writeln!(
5290        code,
5291        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5292    )?;
5293    writeln!(
5294        code,
5295        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5296    )?;
5297    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5298    writeln!(code, "    }}")?;
5299    writeln!(code)?;
5300
5301    // dispatch_copy (simple src -> dst copy via compute kernel)
5302    writeln!(
5303        code,
5304        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5305    )?;
5306    writeln!(
5307        code,
5308        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5309    )?;
5310    writeln!(code, "        let n: u32 = count as u32;")?;
5311    writeln!(
5312        code,
5313        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5314    )?;
5315    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5316    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5317    writeln!(
5318        code,
5319        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5320    )?;
5321    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5322    writeln!(
5323        code,
5324        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5325    )?;
5326    writeln!(
5327        code,
5328        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5329    )?;
5330    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5331    writeln!(code, "    }}")?;
5332    writeln!(code)?;
5333
5334    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
5335    writeln!(
5336        code,
5337        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5338    )?;
5339    writeln!(
5340        code,
5341        "    /// Used for embedding table lookup (copy a specific row)."
5342    )?;
5343    writeln!(
5344        code,
5345        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5346    )?;
5347    writeln!(code, "        let off: u32 = src_offset as u32;")?;
5348    writeln!(code, "        let n: u32 = count as u32;")?;
5349    writeln!(
5350        code,
5351        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5352    )?;
5353    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5354    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5355    writeln!(
5356        code,
5357        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5358    )?;
5359    writeln!(
5360        code,
5361        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5362    )?;
5363    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5364    writeln!(
5365        code,
5366        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5367    )?;
5368    writeln!(
5369        code,
5370        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5371    )?;
5372    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5373    writeln!(code, "    }}")?;
5374    writeln!(code)?;
5375
5376    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
5377    writeln!(
5378        code,
5379        "    /// Dispatch copy from source at byte offset to destination at float offset."
5380    )?;
5381    writeln!(
5382        code,
5383        "    /// Used for KV cache updates from fused QKV buffer."
5384    )?;
5385    writeln!(
5386        code,
5387        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5388    )?;
5389    writeln!(code, "        let n: u32 = count as u32;")?;
5390    writeln!(
5391        code,
5392        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5393    )?;
5394    writeln!(
5395        code,
5396        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5397    )?;
5398    writeln!(
5399        code,
5400        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5401    )?;
5402    writeln!(
5403        code,
5404        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5405    )?;
5406    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5407    writeln!(
5408        code,
5409        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5410    )?;
5411    writeln!(
5412        code,
5413        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5414    )?;
5415    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5416    writeln!(code, "    }}")?;
5417    writeln!(code)?;
5418
5419    // dispatch_copy_from_offset_f16 (copy src[src_byte_offset..] -> f16 dst[dst_elem_offset..])
5420    writeln!(
5421        code,
5422        "    /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5423    )?;
5424    writeln!(
5425        code,
5426        "    /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5427    )?;
5428    writeln!(
5429        code,
5430        "    fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5431    )?;
5432    writeln!(code, "        let n: u32 = count as u32;")?;
5433    writeln!(
5434        code,
5435        "        enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5436    )?;
5437    writeln!(
5438        code,
5439        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5440    )?;
5441    writeln!(
5442        code,
5443        "        enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5444    )?;
5445    writeln!(
5446        code,
5447        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5448    )?;
5449    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5450    writeln!(
5451        code,
5452        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5453    )?;
5454    writeln!(
5455        code,
5456        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5457    )?;
5458    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5459    writeln!(code, "    }}")?;
5460    writeln!(code)?;
5461
5462    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
5463    writeln!(
5464        code,
5465        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5466    )?;
5467    writeln!(
5468        code,
5469        "    /// Used for KV cache updates (write to a specific position in the cache)."
5470    )?;
5471    writeln!(
5472        code,
5473        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5474    )?;
5475    writeln!(code, "        let n: u32 = count as u32;")?;
5476    writeln!(
5477        code,
5478        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5479    )?;
5480    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5481    writeln!(
5482        code,
5483        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5484    )?;
5485    writeln!(
5486        code,
5487        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5488    )?;
5489    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5490    writeln!(
5491        code,
5492        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5493    )?;
5494    writeln!(
5495        code,
5496        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5497    )?;
5498    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5499    writeln!(code, "    }}")?;
5500    writeln!(code)?;
5501
5502    // dispatch_add_inplace (residual connection, no blit needed)
5503    writeln!(
5504        code,
5505        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5506    )?;
5507    writeln!(
5508        code,
5509        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5510    )?;
5511    writeln!(code, "        let count: u32 = n as u32;")?;
5512    writeln!(
5513        code,
5514        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5515    )?;
5516    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5517    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5518    writeln!(
5519        code,
5520        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5521    )?;
5522    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5523    writeln!(
5524        code,
5525        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5526    )?;
5527    writeln!(
5528        code,
5529        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5530    )?;
5531    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5532    writeln!(code, "    }}")?;
5533    writeln!(code)?;
5534
5535    // ── Batched prefill dispatch helpers ──
5536    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
5537    writeln!(code)?;
5538
5539    // dispatch_copy_embedding_batch
5540    writeln!(
5541        code,
5542        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
5543    )?;
5544    writeln!(
5545        code,
5546        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5547    )?;
5548    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
5549    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5550    writeln!(
5551        code,
5552        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5553    )?;
5554    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5555    writeln!(
5556        code,
5557        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5558    )?;
5559    writeln!(
5560        code,
5561        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5562    )?;
5563    writeln!(
5564        code,
5565        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5566    )?;
5567    writeln!(
5568        code,
5569        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5570    )?;
5571    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
5572    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5573    writeln!(
5574        code,
5575        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5576    )?;
5577    writeln!(
5578        code,
5579        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5580    )?;
5581    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5582    writeln!(code, "    }}")?;
5583    writeln!(code)?;
5584
5585    // dispatch_scale_buffer (Gemma-1 embedding scaling)
5586    writeln!(
5587        code,
5588        "    /// Multiply every element of `data` by `scale` in place."
5589    )?;
5590    writeln!(
5591        code,
5592        "    fn dispatch_scale_buffer(&self, enc: &ComputeCommandEncoderRef, data: &Buffer, scale: f32, count: usize) {{"
5593    )?;
5594    writeln!(code, "        let n: u32 = count as u32;")?;
5595    writeln!(
5596        code,
5597        "        enc.set_compute_pipeline_state(&self.scale_buffer_pipeline);"
5598    )?;
5599    writeln!(code, "        enc.set_buffer(0, Some(data), 0);")?;
5600    writeln!(
5601        code,
5602        "        enc.set_bytes(1, mem::size_of::<f32>() as u64, &scale as *const f32 as *const _);"
5603    )?;
5604    writeln!(
5605        code,
5606        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5607    )?;
5608    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5609    writeln!(
5610        code,
5611        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5612    )?;
5613    writeln!(
5614        code,
5615        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5616    )?;
5617    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5618    writeln!(code, "    }}")?;
5619    writeln!(code)?;
5620
5621    // dispatch_rms_norm_batch
5622    writeln!(
5623        code,
5624        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
5625    )?;
5626    writeln!(
5627        code,
5628        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5629    )?;
5630    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
5631    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
5632    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5633    writeln!(
5634        code,
5635        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5636    )?;
5637    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
5638    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
5639    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5640    writeln!(
5641        code,
5642        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5643    )?;
5644    writeln!(
5645        code,
5646        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5647    )?;
5648    writeln!(
5649        code,
5650        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5651    )?;
5652    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5653    writeln!(
5654        code,
5655        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5656    )?;
5657    writeln!(
5658        code,
5659        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5660    )?;
5661    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5662    writeln!(code, "    }}")?;
5663    writeln!(code)?;
5664
5665    // dispatch_matmul_batch (f32)
5666    writeln!(
5667        code,
5668        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5669    )?;
5670    writeln!(
5671        code,
5672        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5673    )?;
5674    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5675    writeln!(code, "        let r: u32 = rows as u32;")?;
5676    writeln!(code, "        let c: u32 = cols as u32;")?;
5677    writeln!(
5678        code,
5679        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5680    )?;
5681    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5682    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5683    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5684    writeln!(
5685        code,
5686        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5687    )?;
5688    writeln!(
5689        code,
5690        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5691    )?;
5692    writeln!(
5693        code,
5694        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5695    )?;
5696    writeln!(
5697        code,
5698        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
5699    )?;
5700    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5701    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5702    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5703    writeln!(
5704        code,
5705        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5706    )?;
5707    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5708    writeln!(code, "    }}")?;
5709    writeln!(code)?;
5710
5711    // dispatch_matmul_q8_batch
5712    writeln!(
5713        code,
5714        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5715    )?;
5716    writeln!(code, "    ///")?;
5717    writeln!(
5718        code,
5719        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5720    )?;
5721    writeln!(
5722        code,
5723        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5724    )?;
5725    writeln!(
5726        code,
5727        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5728    )?;
5729    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5730    writeln!(code, "        let r: u32 = rows as u32;")?;
5731    writeln!(code, "        let c: u32 = cols as u32;")?;
5732    writeln!(
5733        code,
5734        "        // Tile sizes must match the Metal shader constants."
5735    )?;
5736    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5737    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5738    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5739    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5740    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5741    writeln!(
5742        code,
5743        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5744    )?;
5745    writeln!(
5746        code,
5747        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5748    )?;
5749    writeln!(
5750        code,
5751        "        // dispatch count and reuses each weight load across 32 tokens."
5752    )?;
5753    writeln!(
5754        code,
5755        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5756    )?;
5757    writeln!(
5758        code,
5759        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5760    )?;
5761    writeln!(
5762        code,
5763        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5764    )?;
5765    writeln!(
5766        code,
5767        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5768    )?;
5769    writeln!(
5770        code,
5771        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5772    )?;
5773    writeln!(
5774        code,
5775        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5776    )?;
5777    writeln!(
5778        code,
5779        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5780    )?;
5781    writeln!(
5782        code,
5783        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5784    )?;
5785    writeln!(
5786        code,
5787        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5788    )?;
5789    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5790    writeln!(code, "            let pipe = if use_h4 {{")?;
5791    writeln!(code, "                if num_tokens >= 256 {{")?;
5792    writeln!(
5793        code,
5794        "                    &self.matmul_q8_mma32_hh4_pipeline"
5795    )?;
5796    writeln!(code, "                }} else {{")?;
5797    writeln!(
5798        code,
5799        "                    &self.matmul_q8_mma32_h4_pipeline"
5800    )?;
5801    writeln!(code, "                }}")?;
5802    writeln!(code, "            }} else {{")?;
5803    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5804    writeln!(code, "            }};")?;
5805    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
5806    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5807    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5808    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5809    writeln!(
5810        code,
5811        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5812    )?;
5813    writeln!(
5814        code,
5815        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5816    )?;
5817    writeln!(
5818        code,
5819        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5820    )?;
5821    writeln!(code, "            let row_tgs = rows / MMA32_ROW_TILE;")?;
5822    writeln!(
5823        code,
5824        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5825    )?;
5826    writeln!(
5827        code,
5828        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5829    )?;
5830    writeln!(
5831        code,
5832        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5833    )?;
5834    writeln!(
5835        code,
5836        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5837    )?;
5838    writeln!(
5839        code,
5840        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5841    )?;
5842    writeln!(
5843        code,
5844        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5845    )?;
5846    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5847    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5848    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5849    writeln!(
5850        code,
5851        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5852    )?;
5853    writeln!(
5854        code,
5855        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5856    )?;
5857    writeln!(
5858        code,
5859        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5860    )?;
5861    writeln!(code, "            let row_tgs = rows / MMA_ROW_TILE;")?;
5862    writeln!(
5863        code,
5864        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5865    )?;
5866    writeln!(
5867        code,
5868        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5869    )?;
5870    writeln!(
5871        code,
5872        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5873    )?;
5874    writeln!(
5875        code,
5876        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5877    )?;
5878    // Route to gemm for cols > 8192 too — the per-token matmul_q8_batch
5879    // caches the input vector in a VEC_TILE_SIZE (8192-float) threadgroup
5880    // array, which overflows for cols = intermediate_size > 8192 (e.g.
5881    // Gemma-2B's 16384-wide down-proj). Gemm reads inputs directly from
5882    // device memory so it handles any cols width.
5883    writeln!(
5884        code,
5885        "        }} else if num_tokens >= TOKENS_PER_TG_Q8 || cols > 8192 {{"
5886    )?;
5887    writeln!(
5888        code,
5889        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5890    )?;
5891    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5892    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5893    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5894    writeln!(
5895        code,
5896        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5897    )?;
5898    writeln!(
5899        code,
5900        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5901    )?;
5902    writeln!(
5903        code,
5904        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5905    )?;
5906    writeln!(
5907        code,
5908        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5909    )?;
5910    writeln!(
5911        code,
5912        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5913    )?;
5914    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5915    writeln!(
5916        code,
5917        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5918    )?;
5919    writeln!(
5920        code,
5921        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5922    )?;
5923    writeln!(code, "        }} else {{")?;
5924    writeln!(
5925        code,
5926        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5927    )?;
5928    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5929    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5930    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5931    writeln!(
5932        code,
5933        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5934    )?;
5935    writeln!(
5936        code,
5937        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5938    )?;
5939    writeln!(
5940        code,
5941        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5942    )?;
5943    writeln!(
5944        code,
5945        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5946    )?;
5947    writeln!(
5948        code,
5949        "            let num_tg = (row_tgs * num_tokens) as u64;"
5950    )?;
5951    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5952    writeln!(
5953        code,
5954        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5955    )?;
5956    writeln!(
5957        code,
5958        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5959    )?;
5960    writeln!(code, "        }}")?;
5961    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5962    writeln!(code, "    }}")?;
5963    writeln!(code)?;
5964
5965    // dispatch_matmul_q4_batch
5966    writeln!(
5967        code,
5968        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5969    )?;
5970    writeln!(
5971        code,
5972        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5973    )?;
5974    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5975    writeln!(code, "        let r: u32 = rows as u32;")?;
5976    writeln!(code, "        let c: u32 = cols as u32;")?;
5977    writeln!(
5978        code,
5979        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5980    )?;
5981    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5982    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5983    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5984    writeln!(
5985        code,
5986        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5987    )?;
5988    writeln!(
5989        code,
5990        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5991    )?;
5992    writeln!(
5993        code,
5994        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5995    )?;
5996    writeln!(
5997        code,
5998        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5999    )?;
6000    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
6001    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6002    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
6003    writeln!(
6004        code,
6005        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6006    )?;
6007    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6008    writeln!(code, "    }}")?;
6009    writeln!(code)?;
6010
6011    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
6012    if config.qkv_bias {
6013        writeln!(
6014            code,
6015            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
6016        )?;
6017        writeln!(
6018            code,
6019            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
6020        )?;
6021        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
6022        writeln!(code, "        let r: u32 = rows as u32;")?;
6023        writeln!(
6024            code,
6025            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
6026        )?;
6027        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
6028        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
6029        writeln!(
6030            code,
6031            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
6032        )?;
6033        writeln!(
6034            code,
6035            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
6036        )?;
6037        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
6038        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6039        writeln!(
6040            code,
6041            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
6042        )?;
6043        writeln!(
6044            code,
6045            "        enc.dispatch_thread_groups(grid_size, tg_size);"
6046        )?;
6047        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6048        writeln!(code, "    }}")?;
6049        writeln!(code)?;
6050    }
6051
6052    // dispatch_rope_batch
6053    writeln!(
6054        code,
6055        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
6056    )?;
6057    writeln!(
6058        code,
6059        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
6060    )?;
6061    writeln!(
6062        code,
6063        "    /// `row_stride` is the number of floats per token row in the batch buffer."
6064    )?;
6065    writeln!(
6066        code,
6067        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
6068    )?;
6069    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
6070    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
6071    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
6072    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
6073    writeln!(
6074        code,
6075        "        let pairs_per_token = num_heads * (head_dim / 2);"
6076    )?;
6077    writeln!(
6078        code,
6079        "        let total_pairs = num_tokens * pairs_per_token;"
6080    )?;
6081    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data,
6082    // i.e. it indexes data[token * (num_heads * head_dim) + ...].
6083    // Our batch_qkv_buf is laid out as [M, qkv_rows], with Q and K stored at
6084    // offsets inside each token's row, so the element we actually need is
6085    // data[token * row_stride + data_float_offset + ...].
6086    // Binding the buffer at byte offset (data_float_offset * 4) is not enough
6087    // on its own: the kernel would also need a row-stride parameter, which
6088    // rope_batch as written does not take.
6089    //
6090    // Simplest approach: run the single-token rope kernel once per token, binding
6091    // the buffer at that token's byte offset. This is still efficient because all
6092    // dispatches are encoded within the same command encoder.
6093    writeln!(
6094        code,
6095        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
6096    )?;
6097    writeln!(code, "        for t in 0..num_tokens {{")?;
6098    writeln!(
6099        code,
6100        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
6101    )?;
6102    writeln!(
6103        code,
6104        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
6105    )?;
6106    writeln!(
6107        code,
6108        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
6109    )?;
6110    writeln!(
6111        code,
6112        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
6113    )?;
6114    writeln!(
6115        code,
6116        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6117    )?;
6118    writeln!(
6119        code,
6120        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6121    )?;
6122    writeln!(
6123        code,
6124        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
6125    )?;
6126    writeln!(
6127        code,
6128        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6129    )?;
6130    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
6131    writeln!(
6132        code,
6133        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
6134    )?;
6135    writeln!(
6136        code,
6137        "            enc.dispatch_thread_groups(grid_size, tg_size);"
6138    )?;
6139    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6140    writeln!(code, "        }}")?;
6141    writeln!(code, "    }}")?;
6142    writeln!(code)?;
6143
6144    // dispatch_silu_mul_fused_batch
6145    writeln!(
6146        code,
6147        "    /// Dispatch batched fused SiLU-multiply for M tokens."
6148    )?;
6149    writeln!(
6150        code,
6151        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
6152    )?;
6153    writeln!(code, "        let count: u32 = n as u32;")?;
6154    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
6155    writeln!(
6156        code,
6157        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
6158    )?;
6159    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
6160    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
6161    writeln!(
6162        code,
6163        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
6164    )?;
6165    writeln!(
6166        code,
6167        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
6168    )?;
6169    writeln!(code, "        let total = num_tokens * n;")?;
6170    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6171    writeln!(
6172        code,
6173        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6174    )?;
6175    writeln!(
6176        code,
6177        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6178    )?;
6179    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6180    writeln!(code, "    }}")?;
6181    writeln!(code)?;
6182
6183    // dispatch_add_inplace_batch_n (add n elements in-place)
6184    writeln!(
6185        code,
6186        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
6187    )?;
6188    writeln!(
6189        code,
6190        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
6191    )?;
6192    writeln!(code, "        let count: u32 = total_n as u32;")?;
6193    writeln!(
6194        code,
6195        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
6196    )?;
6197    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
6198    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
6199    writeln!(
6200        code,
6201        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
6202    )?;
6203    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6204    writeln!(
6205        code,
6206        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
6207    )?;
6208    writeln!(
6209        code,
6210        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6211    )?;
6212    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6213    writeln!(code, "    }}")?;
6214    writeln!(code)?;
6215
6216    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
6217    writeln!(
6218        code,
6219        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
6220    )?;
6221    writeln!(
6222        code,
6223        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
6224    )?;
6225    writeln!(code, "        let n: u32 = count as u32;")?;
6226    writeln!(
6227        code,
6228        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6229    )?;
6230    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6231    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6232    writeln!(
6233        code,
6234        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6235    )?;
6236    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6237    writeln!(
6238        code,
6239        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6240    )?;
6241    writeln!(
6242        code,
6243        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6244    )?;
6245    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6246    writeln!(code, "    }}")?;
6247    writeln!(code)?;
6248
6249    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
6250    writeln!(
6251        code,
6252        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6253    )?;
6254    writeln!(
6255        code,
6256        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6257    )?;
6258    writeln!(code, "        let n: u32 = count as u32;")?;
6259    writeln!(
6260        code,
6261        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6262    )?;
6263    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6264    writeln!(
6265        code,
6266        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6267    )?;
6268    writeln!(
6269        code,
6270        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6271    )?;
6272    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6273    writeln!(
6274        code,
6275        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6276    )?;
6277    writeln!(
6278        code,
6279        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6280    )?;
6281    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6282    writeln!(code, "    }}")?;
6283    writeln!(code)?;
6284
6285    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
6286    writeln!(
6287        code,
6288        "    /// Copy from src at byte offset to dst at float offset."
6289    )?;
6290    writeln!(
6291        code,
6292        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6293    )?;
6294    writeln!(code, "        let n: u32 = count as u32;")?;
6295    writeln!(
6296        code,
6297        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6298    )?;
6299    writeln!(
6300        code,
6301        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6302    )?;
6303    writeln!(
6304        code,
6305        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6306    )?;
6307    writeln!(
6308        code,
6309        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6310    )?;
6311    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6312    writeln!(
6313        code,
6314        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6315    )?;
6316    writeln!(
6317        code,
6318        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6319    )?;
6320    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6321    writeln!(code, "    }}")?;
6322    writeln!(code)?;
6323
6324    // dispatch_copy_kv_batch
6325    writeln!(
6326        code,
6327        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6328    )?;
6329    writeln!(
6330        code,
6331        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6332    )?;
6333    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6334    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6335    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6336    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6337    writeln!(code, "        let so: u32 = src_offset as u32;")?;
6338    writeln!(
6339        code,
6340        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6341    )?;
6342    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6343    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6344    writeln!(
6345        code,
6346        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6347    )?;
6348    writeln!(
6349        code,
6350        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6351    )?;
6352    writeln!(
6353        code,
6354        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6355    )?;
6356    writeln!(
6357        code,
6358        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6359    )?;
6360    writeln!(
6361        code,
6362        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6363    )?;
6364    writeln!(code, "        let total = num_tokens * kv_dim;")?;
6365    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6366    writeln!(
6367        code,
6368        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6369    )?;
6370    writeln!(
6371        code,
6372        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6373    )?;
6374    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6375    writeln!(code, "    }}")?;
6376    writeln!(code)?;
6377
6378    // dispatch_attention_batch
6379    writeln!(
6380        code,
6381        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
6382    )?;
6383    writeln!(
6384        code,
6385        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6386    )?;
6387    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6388    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6389    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
6390    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6391    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6392    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
6393    // Attention kernel selection:
6394    //   * Legacy `attention_batch` materializes scores[4096] in threadgroup memory
6395    //     and uses scalar simdgroup reductions.  Fast at short seq_len, no MMA.
6396    //   * `attention_flash_batch` streams K/V with online softmax; no seq cap,
6397    //     scalar math, ~7-14 % slower than legacy at long contexts (no MMA).
6398    //   * `attention_mma_flash_batch` adds hardware simdgroup_matrix<half, 8, 8>
6399    //     MMA for both Q·K^T and P·V, processing Q_BLOCK=8 tokens per TG.
6400    //     Default path when HEAD_DIM ≤ 128 and num_tokens ≥ 8 (verified on
6401    //     Llama, Qwen2.5, Phi-3).  Set FORGE_MMA_ATTN=0 to force legacy.
6402    writeln!(code, "        let max_seq = base_pos + num_tokens;")?;
6403    writeln!(code, "        let _ = max_seq;")?;
6404    writeln!(
6405        code,
6406        "        let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6407    )?;
6408    writeln!(code, "            .map(|v| v == \"0\").unwrap_or(false);")?;
6409    writeln!(
6410        code,
6411        "        let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6412    )?;
6413    writeln!(code, "        if use_mma_flash {{")?;
6414    writeln!(
6415        code,
6416        "            let pipe = &self.attention_mma_flash_batch_pipeline;"
6417    )?;
6418    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
6419    writeln!(code, "            enc.set_buffer(0, Some(q_buf), 0);")?;
6420    writeln!(code, "            enc.set_buffer(1, Some(k_cache), 0);")?;
6421    writeln!(code, "            enc.set_buffer(2, Some(v_cache), 0);")?;
6422    writeln!(code, "            enc.set_buffer(3, Some(output), 0);")?;
6423    writeln!(
6424        code,
6425        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6426    )?;
6427    writeln!(
6428        code,
6429        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6430    )?;
6431    writeln!(
6432        code,
6433        "            enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6434    )?;
6435    writeln!(
6436        code,
6437        "            enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6438    )?;
6439    writeln!(
6440        code,
6441        "            enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6442    )?;
6443    writeln!(
6444        code,
6445        "            enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6446    )?;
6447    writeln!(
6448        code,
6449        "            // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6450    )?;
6451    writeln!(code, "            let tg_size = MTLSize::new(128, 1, 1);")?;
6452    writeln!(
6453        code,
6454        "            let q_blocks = ((num_tokens + 7) / 8) as u64;"
6455    )?;
6456    writeln!(
6457        code,
6458        "            let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6459    )?;
6460    writeln!(
6461        code,
6462        "            enc.dispatch_thread_groups(grid_size, tg_size);"
6463    )?;
6464    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6465    writeln!(code, "            return;")?;
6466    writeln!(code, "        }}")?;
6467    writeln!(code, "        let pipe = &self.attention_batch_pipeline;")?;
6468    writeln!(code, "        enc.set_compute_pipeline_state(pipe);")?;
6469    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
6470    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
6471    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
6472    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
6473    writeln!(
6474        code,
6475        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6476    )?;
6477    writeln!(
6478        code,
6479        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6480    )?;
6481    writeln!(
6482        code,
6483        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6484    )?;
6485    writeln!(
6486        code,
6487        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6488    )?;
6489    writeln!(
6490        code,
6491        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6492    )?;
6493    writeln!(
6494        code,
6495        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6496    )?;
6497    writeln!(
6498        code,
6499        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6500    )?;
6501    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6502    writeln!(
6503        code,
6504        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6505    )?;
6506    writeln!(
6507        code,
6508        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6509    )?;
6510    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6511    writeln!(code, "    }}")?;
6512    writeln!(code)?;
6513
6514    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
6515    writeln!(
6516        code,
6517        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6518    )?;
6519    writeln!(
6520        code,
6521        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6522    )?;
6523    writeln!(
6524        code,
6525        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6526    )?;
6527    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6528    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6529    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
6530    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6531    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6532    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
6533    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
6534    writeln!(
6535        code,
6536        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6537    )?;
6538    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
6539    writeln!(
6540        code,
6541        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6542    )?;
6543    writeln!(
6544        code,
6545        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6546    )?;
6547    writeln!(
6548        code,
6549        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6550    )?;
6551    writeln!(
6552        code,
6553        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6554    )?;
6555    writeln!(
6556        code,
6557        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6558    )?;
6559    writeln!(
6560        code,
6561        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6562    )?;
6563    writeln!(
6564        code,
6565        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6566    )?;
6567    writeln!(
6568        code,
6569        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6570    )?;
6571    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6572    writeln!(
6573        code,
6574        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6575    )?;
6576    writeln!(
6577        code,
6578        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6579    )?;
6580    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6581    writeln!(code, "    }}")?;
6582    writeln!(code)?;
6583
6584    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
6585    writeln!(
6586        code,
6587        "    /// Dispatch fused K+V cache copy in one kernel launch."
6588    )?;
6589    writeln!(
6590        code,
6591        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6592    )?;
6593    writeln!(
6594        code,
6595        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6596    )?;
6597    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6598    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6599    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6600    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6601    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6602    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6603    writeln!(
6604        code,
6605        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6606    )?;
6607    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6608    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6609    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6610    writeln!(
6611        code,
6612        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6613    )?;
6614    writeln!(
6615        code,
6616        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6617    )?;
6618    writeln!(
6619        code,
6620        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6621    )?;
6622    writeln!(
6623        code,
6624        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6625    )?;
6626    writeln!(
6627        code,
6628        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6629    )?;
6630    writeln!(
6631        code,
6632        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6633    )?;
6634    writeln!(
6635        code,
6636        "        let total = num_tokens * kv_dim * 2;  // K + V"
6637    )?;
6638    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6639    writeln!(
6640        code,
6641        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6642    )?;
6643    writeln!(
6644        code,
6645        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6646    )?;
6647    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6648    writeln!(code, "    }}")?;
6649
6650    writeln!(code, "}}")?;
6651    writeln!(code)?;
6652
6653    Ok(())
6654}
6655
6656fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6657    writeln!(
6658        code,
6659        "// ── Helper functions ──────────────────────────────────"
6660    )?;
6661    writeln!(code)?;
6662    writeln!(
6663        code,
6664        "/// Create a compute pipeline from a named function in the Metal library."
6665    )?;
6666    writeln!(
6667        code,
6668        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6669    )?;
6670    writeln!(
6671        code,
6672        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6673    )?;
6674    writeln!(
6675        code,
6676        "    device.new_compute_pipeline_state_with_function(&func)"
6677    )?;
6678    writeln!(
6679        code,
6680        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6681    )?;
6682    writeln!(code, "}}")?;
6683    writeln!(code)?;
6684
6685    Ok(())
6686}
6687
6688// ---------------------------------------------------------------------------
6689// main.rs generation
6690// ---------------------------------------------------------------------------
6691
6692fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6693    let _sanitized = sanitize_name(model_name);
6694    let _vocab = config.vocab_size;
6695
6696    let mut code = String::with_capacity(16 * 1024);
6697    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6698    writeln!(
6699        code,
6700        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6701    )?;
6702    writeln!(code)?;
6703    writeln!(code, "mod model;")?;
6704    writeln!(code)?;
6705    writeln!(code, "use std::io::Write;")?;
6706    writeln!(code, "use std::time::Instant;")?;
6707    writeln!(code, "use serde::Deserialize;")?;
6708    writeln!(code)?;
6709
6710    // -- main function --
6711    writeln!(code, "fn main() {{")?;
6712    writeln!(
6713        code,
6714        "    let args: Vec<String> = std::env::args().collect();"
6715    )?;
6716    writeln!(code)?;
6717    writeln!(
6718        code,
6719        "    // Detect --serve mode (only requires weights + tokenizer)"
6720    )?;
6721    writeln!(
6722        code,
6723        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6724    )?;
6725    writeln!(code)?;
6726    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6727    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6728    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6729    writeln!(code, "        std::process::exit(1);")?;
6730    writeln!(code, "    }}")?;
6731    writeln!(code)?;
6732    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6733    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6734    writeln!(code, "        std::process::exit(1);")?;
6735    writeln!(code, "    }}")?;
6736    writeln!(code)?;
6737    writeln!(code, "    let weights_path = &args[1];")?;
6738    writeln!(code, "    let tokenizer_path = &args[2];")?;
6739    writeln!(code)?;
6740    writeln!(code, "    // Parse optional flags")?;
6741    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6742    writeln!(code, "    let mut port: u16 = 8080;")?;
6743    writeln!(
6744        code,
6745        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6746    )?;
6747    writeln!(
6748        code,
6749        "    let profile = args.iter().any(|a| a == \"--profile\");"
6750    )?;
6751    writeln!(code, "    let mut i = 3;")?;
6752    writeln!(code, "    while i < args.len() {{")?;
6753    writeln!(
6754        code,
6755        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6756    )?;
6757    writeln!(
6758        code,
6759        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6760    )?;
6761    writeln!(code, "            i += 2;")?;
6762    writeln!(
6763        code,
6764        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6765    )?;
6766    writeln!(
6767        code,
6768        "            port = args[i + 1].parse().unwrap_or(8080);"
6769    )?;
6770    writeln!(code, "            i += 2;")?;
6771    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6772    writeln!(code, "            i += 1;")?;
6773    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6774    writeln!(code, "            i += 1;")?;
6775    writeln!(code, "        }} else {{")?;
6776    writeln!(code, "            i += 1;")?;
6777    writeln!(code, "        }}")?;
6778    writeln!(code, "    }}")?;
6779    writeln!(code)?;
6780
6781    // -- load model (shared by both modes) --
6782    writeln!(
6783        code,
6784        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6785    )?;
6786    writeln!(
6787        code,
6788        "    let weights_file = std::fs::File::open(weights_path)"
6789    )?;
6790    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6791    writeln!(
6792        code,
6793        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6794    )?;
6795    writeln!(code)?;
6796    writeln!(code, "    // Load tokenizer")?;
6797    writeln!(
6798        code,
6799        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6800    )?;
6801    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6802    writeln!(code)?;
6803    writeln!(code, "    // Create Metal model")?;
6804    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6805    writeln!(
6806        code,
6807        "    let mut model = model::MetalModel::new(&weights_mmap);"
6808    )?;
6809    writeln!(code)?;
6810
6811    // -- branch: serve vs CLI --
6812    writeln!(code, "    if serve_mode {{")?;
6813    writeln!(code, "        serve(model, tokenizer, port);")?;
6814    writeln!(code, "    }} else {{")?;
6815    writeln!(code, "        let prompt = &args[3];")?;
6816    writeln!(
6817        code,
6818        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6819    )?;
6820    writeln!(code, "    }}")?;
6821    writeln!(code, "}}")?;
6822    writeln!(code)?;
6823
6824    // -- cli_mode function --
6825    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6826    writeln!(code, "    // Tokenize prompt")?;
6827    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6828    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6829    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6830    writeln!(code)?;
6831    writeln!(
6832        code,
6833        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6834    )?;
6835    writeln!(
6836        code,
6837        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6838    )?;
6839    writeln!(
6840        code,
6841        "    // The last token uses synchronous forward() to get logits."
6842    )?;
6843    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6844    writeln!(code, "    let prefill_start = Instant::now();")?;
6845    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6846    writeln!(
6847        code,
6848        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6849    )?;
6850    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6851    writeln!(code, "    }} else {{")?;
6852    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6853    writeln!(code, "    }};")?;
6854    writeln!(
6855        code,
6856        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6857    )?;
6858    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6859    writeln!(
6860        code,
6861        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6862    )?;
6863    writeln!(
6864        code,
6865        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6866    )?;
6867    writeln!(code)?;
6868    writeln!(code, "    // Generate tokens")?;
6869    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6870    writeln!(code, "    let gen_start = Instant::now();")?;
6871    writeln!(code, "    let mut generated_count: usize = 0;")?;
6872    writeln!(code)?;
6873    writeln!(code, "    for _ in 0..max_tokens {{")?;
6874    writeln!(
6875        code,
6876        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6877    )?;
6878    writeln!(code, "            if !quiet {{")?;
6879    writeln!(code, "                print!(\"{{}}\", text);")?;
6880    writeln!(code, "                std::io::stdout().flush().ok();")?;
6881    writeln!(code, "            }}")?;
6882    writeln!(code, "        }}")?;
6883    writeln!(code, "        generated_count += 1;")?;
6884    writeln!(code)?;
6885    writeln!(
6886        code,
6887        "        // Use profiling forward for first token when --profile is set"
6888    )?;
6889    writeln!(
6890        code,
6891        "        let logits = if profile && generated_count == 1 {{"
6892    )?;
6893    writeln!(code, "            model.forward_profile(next_token)")?;
6894    writeln!(code, "        }} else {{")?;
6895    writeln!(code, "            model.forward(next_token)")?;
6896    writeln!(code, "        }};")?;
6897    writeln!(code, "        next_token = argmax(&logits);")?;
6898    writeln!(code)?;
6899    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6900    writeln!(code, "        if next_token == 2 {{")?;
6901    writeln!(code, "            break;")?;
6902    writeln!(code, "        }}")?;
6903    writeln!(code)?;
6904    writeln!(
6905        code,
6906        "        // Yield between tokens to reduce sustained GPU thermal load."
6907    )?;
6908    writeln!(
6909        code,
6910        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6911    )?;
6912    writeln!(
6913        code,
6914        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6915    )?;
6916    writeln!(
6917        code,
6918        "        // briefly, providing a micro-break that helps sustain peak throughput."
6919    )?;
6920    writeln!(code, "        std::thread::yield_now();")?;
6921    writeln!(code, "    }}")?;
6922    writeln!(code, "    if !quiet {{")?;
6923    writeln!(code, "        println!();")?;
6924    writeln!(code, "    }}")?;
6925    writeln!(
6926        code,
6927        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6928    )?;
6929    writeln!(
6930        code,
6931        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6932    )?;
6933    writeln!(
6934        code,
6935        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6936    )?;
6937    writeln!(code, "}}")?;
6938    writeln!(code)?;
6939
6940    // -- argmax helper --
6941    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6942    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6943    writeln!(code, "    logits.iter()")?;
6944    writeln!(code, "        .enumerate()")?;
6945    writeln!(
6946        code,
6947        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6948    )?;
6949    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6950    writeln!(code, "        .unwrap_or(0)")?;
6951    writeln!(code, "}}")?;
6952    writeln!(code)?;
6953
6954    // -- Request/Response types for OpenAI API --
6955    writeln!(
6956        code,
6957        "// -----------------------------------------------------------------------"
6958    )?;
6959    writeln!(code, "// OpenAI-compatible API server")?;
6960    writeln!(
6961        code,
6962        "// -----------------------------------------------------------------------"
6963    )?;
6964    writeln!(code)?;
6965    writeln!(code, "#[derive(Deserialize)]")?;
6966    writeln!(code, "struct ChatRequest {{")?;
6967    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6968    writeln!(code, "    #[serde(default)]")?;
6969    writeln!(code, "    stream: Option<bool>,")?;
6970    writeln!(code, "    #[serde(default)]")?;
6971    writeln!(code, "    max_tokens: Option<usize>,")?;
6972    writeln!(code, "    #[serde(default)]")?;
6973    writeln!(code, "    temperature: Option<f32>,")?;
6974    writeln!(code, "    #[serde(default)]")?;
6975    writeln!(code, "    model: Option<String>,")?;
6976    writeln!(code, "}}")?;
6977    writeln!(code)?;
6978    writeln!(code, "#[derive(Deserialize)]")?;
6979    writeln!(code, "struct ChatMessage {{")?;
6980    writeln!(code, "    role: String,")?;
6981    writeln!(code, "    content: String,")?;
6982    writeln!(code, "}}")?;
6983    writeln!(code)?;
6984
6985    // -- format_chat_messages --
6986    writeln!(
6987        code,
6988        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6989    )?;
6990    writeln!(code, "    let mut prompt = String::new();")?;
6991    writeln!(code, "    for msg in messages {{")?;
6992    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6993    writeln!(code, "    }}")?;
6994    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6995    writeln!(code, "    prompt")?;
6996    writeln!(code, "}}")?;
6997    writeln!(code)?;
6998
6999    // -- prefill helper --
7000    writeln!(
7001        code,
7002        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
7003    )?;
7004    writeln!(code, "    let len = tokens.len();")?;
7005    writeln!(code, "    if len > 1 {{")?;
7006    writeln!(
7007        code,
7008        "        model.forward_prefill_batch(&tokens[..len - 1]);"
7009    )?;
7010    writeln!(code, "    }}")?;
7011    writeln!(code, "    model.forward(tokens[len - 1])")?;
7012    writeln!(code, "}}")?;
7013    writeln!(code)?;
7014
7015    // -- serve function --
7016    writeln!(
7017        code,
7018        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
7019    )?;
7020    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
7021    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
7022    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
7023    writeln!(
7024        code,
7025        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
7026    )?;
7027    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
7028    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
7029    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
7030    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
7031    writeln!(code)?;
7032    writeln!(code, "    for request in server.incoming_requests() {{")?;
7033    writeln!(code, "        let method = request.method().to_string();")?;
7034    writeln!(code, "        let url = request.url().to_string();")?;
7035    writeln!(code)?;
7036    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
7037
7038    // -- POST /v1/chat/completions --
7039    writeln!(
7040        code,
7041        "            (\"POST\", \"/v1/chat/completions\") => {{"
7042    )?;
7043    writeln!(
7044        code,
7045        "                handle_chat_completion(&mut model, &tokenizer, request);"
7046    )?;
7047    writeln!(code, "            }}")?;
7048
7049    // -- GET /v1/models --
7050    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
7051    writeln!(code, "                let body = serde_json::json!({{")?;
7052    writeln!(code, "                    \"object\": \"list\",")?;
7053    writeln!(code, "                    \"data\": [{{")?;
7054    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
7055    writeln!(code, "                        \"object\": \"model\",")?;
7056    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
7057    writeln!(code, "                    }}]")?;
7058    writeln!(code, "                }});")?;
7059    writeln!(
7060        code,
7061        "                let resp = tiny_http::Response::from_string(body.to_string())"
7062    )?;
7063    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7064    writeln!(code, "                request.respond(resp).ok();")?;
7065    writeln!(code, "            }}")?;
7066
7067    // -- GET /health --
7068    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
7069    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
7070    writeln!(code, "                request.respond(resp).ok();")?;
7071    writeln!(code, "            }}")?;
7072
7073    // -- 404 --
7074    writeln!(code, "            _ => {{")?;
7075    writeln!(
7076        code,
7077        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
7078    )?;
7079    writeln!(code, "                    .with_status_code(404);")?;
7080    writeln!(code, "                request.respond(resp).ok();")?;
7081    writeln!(code, "            }}")?;
7082    writeln!(code, "        }}")?;
7083    writeln!(code, "    }}")?;
7084    writeln!(code, "}}")?;
7085    writeln!(code)?;
7086
7087    // -- handle_chat_completion --
7088    writeln!(code, "fn handle_chat_completion(")?;
7089    writeln!(code, "    model: &mut model::MetalModel,")?;
7090    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
7091    writeln!(code, "    mut request: tiny_http::Request,")?;
7092    writeln!(code, ") {{")?;
7093    writeln!(code, "    // Read request body")?;
7094    writeln!(code, "    let mut body = String::new();")?;
7095    writeln!(
7096        code,
7097        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
7098    )?;
7099    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
7100    writeln!(code, "            .with_status_code(400);")?;
7101    writeln!(code, "        request.respond(resp).ok();")?;
7102    writeln!(code, "        return;")?;
7103    writeln!(code, "    }}")?;
7104    writeln!(code)?;
7105    writeln!(code, "    // Parse JSON")?;
7106    writeln!(
7107        code,
7108        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
7109    )?;
7110    writeln!(code, "        Ok(r) => r,")?;
7111    writeln!(code, "        Err(e) => {{")?;
7112    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
7113    writeln!(code, "                .with_status_code(400);")?;
7114    writeln!(code, "            request.respond(resp).ok();")?;
7115    writeln!(code, "            return;")?;
7116    writeln!(code, "        }}")?;
7117    writeln!(code, "    }};")?;
7118    writeln!(code)?;
7119    writeln!(
7120        code,
7121        "    let prompt = format_chat_messages(&req.messages);"
7122    )?;
7123    writeln!(
7124        code,
7125        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
7126    )?;
7127    writeln!(code, "        Ok(e) => e,")?;
7128    writeln!(code, "        Err(e) => {{")?;
7129    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
7130    writeln!(code, "                .with_status_code(500);")?;
7131    writeln!(code, "            request.respond(resp).ok();")?;
7132    writeln!(code, "            return;")?;
7133    writeln!(code, "        }}")?;
7134    writeln!(code, "    }};")?;
7135    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
7136    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
7137    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
7138    writeln!(
7139        code,
7140        "    let _temperature = req.temperature.unwrap_or(1.0);"
7141    )?;
7142    writeln!(code)?;
7143
7144    // -- Reset KV cache for each request --
7145    writeln!(code, "    model.reset();")?;
7146    writeln!(code)?;
7147
7148    // -- Prefill with timing --
7149    writeln!(code, "    let prefill_start = Instant::now();")?;
7150    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
7151    writeln!(
7152        code,
7153        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
7154    )?;
7155    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
7156    writeln!(code, "    let mut next_token = argmax(&logits);")?;
7157    writeln!(code)?;
7158
7159    writeln!(code, "    if stream {{")?;
7160
7161    // -- SSE streaming response --
7162    writeln!(
7163        code,
7164        "        // SSE streaming: generate tokens and build SSE body"
7165    )?;
7166    writeln!(code, "        let gen_start = Instant::now();")?;
7167    writeln!(code, "        let mut generated_count: usize = 0;")?;
7168    writeln!(code, "        let mut sse_body = String::new();")?;
7169    writeln!(code, "        for _ in 0..max_tokens {{")?;
7170    writeln!(code, "            if next_token == 2 {{ break; }}")?;
7171    writeln!(
7172        code,
7173        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7174    )?;
7175    writeln!(
7176        code,
7177        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
7178    )?;
7179    writeln!(
7180        code,
7181        "                // escaped includes surrounding quotes, strip them"
7182    )?;
7183    writeln!(
7184        code,
7185        "                let inner = &escaped[1..escaped.len()-1];"
7186    )?;
7187    writeln!(code, "                sse_body.push_str(&format!(")?;
7188    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
7189    writeln!(code, "                    inner")?;
7190    writeln!(code, "                ));")?;
7191    writeln!(code, "            }}")?;
7192    writeln!(code, "            generated_count += 1;")?;
7193    writeln!(code, "            let logits = model.forward(next_token);")?;
7194    writeln!(code, "            next_token = argmax(&logits);")?;
7195    writeln!(code, "        }}")?;
7196    writeln!(
7197        code,
7198        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7199    )?;
7200    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7201    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
7202    writeln!(code)?;
7203    writeln!(
7204        code,
7205        "        // Final chunk with finish_reason, timing, and DONE sentinel"
7206    )?;
7207    writeln!(code, "        sse_body.push_str(&format!(")?;
7208    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
7209    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
7210    writeln!(code, "        ));")?;
7211    writeln!(code)?;
7212    writeln!(
7213        code,
7214        "        let resp = tiny_http::Response::from_string(sse_body)"
7215    )?;
7216    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
7217    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
7218    writeln!(code, "        request.respond(resp).ok();")?;
7219
7220    writeln!(code, "    }} else {{")?;
7221
7222    // -- Non-streaming response --
7223    writeln!(
7224        code,
7225        "        // Non-streaming: generate all tokens, return JSON"
7226    )?;
7227    writeln!(code, "        let gen_start = Instant::now();")?;
7228    writeln!(code, "        let mut generated_count: usize = 0;")?;
7229    writeln!(code, "        let mut generated = String::new();")?;
7230    writeln!(code, "        for _ in 0..max_tokens {{")?;
7231    writeln!(code, "            if next_token == 2 {{ break; }}")?;
7232    writeln!(
7233        code,
7234        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7235    )?;
7236    writeln!(code, "                generated.push_str(&text);")?;
7237    writeln!(code, "            }}")?;
7238    writeln!(code, "            generated_count += 1;")?;
7239    writeln!(code, "            let logits = model.forward(next_token);")?;
7240    writeln!(code, "            next_token = argmax(&logits);")?;
7241    writeln!(code, "        }}")?;
7242    writeln!(
7243        code,
7244        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7245    )?;
7246    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7247    writeln!(code)?;
7248    writeln!(code, "        let resp_json = serde_json::json!({{")?;
7249    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
7250    writeln!(code, "            \"object\": \"chat.completion\",")?;
7251    writeln!(code, "            \"choices\": [{{")?;
7252    writeln!(code, "                \"index\": 0,")?;
7253    writeln!(code, "                \"message\": {{")?;
7254    writeln!(code, "                    \"role\": \"assistant\",")?;
7255    writeln!(code, "                    \"content\": generated")?;
7256    writeln!(code, "                }},")?;
7257    writeln!(code, "                \"finish_reason\": \"stop\"")?;
7258    writeln!(code, "            }}],")?;
7259    writeln!(code, "            \"usage\": {{")?;
7260    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
7261    writeln!(
7262        code,
7263        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7264    )?;
7265    writeln!(
7266        code,
7267        "                \"generation_tokens\": generated_count,"
7268    )?;
7269    writeln!(
7270        code,
7271        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7272    )?;
7273    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
7274    writeln!(code, "            }}")?;
7275    writeln!(code, "        }});")?;
7276    writeln!(
7277        code,
7278        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
7279    )?;
7280    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7281    writeln!(code, "        request.respond(resp).ok();")?;
7282    writeln!(code, "    }}")?;
7283    writeln!(code, "}}")?;
7284
7285    Ok(code)
7286}
7287
7288// ---------------------------------------------------------------------------
7289// Tests
7290// ---------------------------------------------------------------------------
7291
7292#[cfg(test)]
7293mod tests {
7294    use super::*;
7295    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7296
7297    fn minimal_config() -> ModelConfig {
7298        ModelConfig {
7299            architecture: Architecture::Llama,
7300            hidden_size: 64,
7301            intermediate_size: 128,
7302            num_layers: 2,
7303            num_attention_heads: 4,
7304            num_kv_heads: 4,
7305            head_dim: 16,
7306            vocab_size: 256,
7307            max_seq_len: 512,
7308            rms_norm_eps: 1e-5,
7309            rope_theta: 10000.0,
7310            dtype: DType::F32,
7311            sliding_window_size: None,
7312            qkv_bias: false,
7313            hidden_activation: HiddenActivation::SiLU,
7314        }
7315    }
7316
7317    fn minimal_graph() -> Graph {
7318        Graph::new("test-metal").with_config(minimal_config())
7319    }
7320
7321    #[test]
7322    fn generate_metal_project_creates_files() {
7323        let dir = tempfile::tempdir().unwrap();
7324        let graph = minimal_graph();
7325        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7326
7327        assert!(
7328            dir.path().join("Cargo.toml").exists(),
7329            "Cargo.toml should be created"
7330        );
7331        assert!(
7332            dir.path().join("src/model.rs").exists(),
7333            "src/model.rs should be created"
7334        );
7335        assert!(
7336            dir.path().join("src/main.rs").exists(),
7337            "src/main.rs should be created"
7338        );
7339        assert!(
7340            dir.path().join("shaders/kernels.metal").exists(),
7341            "shaders/kernels.metal should be created"
7342        );
7343    }
7344
7345    #[test]
7346    fn generated_cargo_toml_has_metal_dep() {
7347        let toml = generate_cargo_toml("my-model");
7348        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7349        assert!(
7350            toml.contains("tokenizers"),
7351            "Cargo.toml should depend on tokenizers"
7352        );
7353        assert!(
7354            toml.contains("memmap2"),
7355            "Cargo.toml should depend on memmap2"
7356        );
7357        assert!(toml.contains("half"), "Cargo.toml should depend on half");
7358    }
7359
7360    #[test]
7361    fn generated_model_rs_contains_metal_code() {
7362        let config = minimal_config();
7363        let model_rs = generate_model_rs(&config).unwrap();
7364
7365        assert!(
7366            model_rs.contains("pub struct MetalModel"),
7367            "model.rs should define MetalModel struct"
7368        );
7369        assert!(
7370            model_rs.contains("matmul_pipeline: ComputePipelineState"),
7371            "MetalModel should have matmul_pipeline field"
7372        );
7373        assert!(
7374            model_rs.contains("Device::system_default()"),
7375            "model.rs should use Metal device"
7376        );
7377        assert!(
7378            model_rs.contains("new_library_with_source"),
7379            "model.rs should compile Metal shaders"
7380        );
7381        assert!(
7382            model_rs.contains("fn new(weights: &[u8])"),
7383            "MetalModel should implement new()"
7384        );
7385        assert!(
7386            model_rs.contains("fn forward(&mut self, token_id: u32)"),
7387            "MetalModel should implement forward()"
7388        );
7389    }
7390
7391    #[test]
7392    fn generated_shaders_contain_kernel_names() {
7393        let shaders = generate_metal_shaders(&minimal_config());
7394
7395        assert!(
7396            shaders.contains("kernel void matmul_vec"),
7397            "shaders should contain matmul_vec kernel"
7398        );
7399        assert!(
7400            shaders.contains("kernel void rms_norm"),
7401            "shaders should contain rms_norm kernel"
7402        );
7403        assert!(
7404            shaders.contains("kernel void rope"),
7405            "shaders should contain rope kernel"
7406        );
7407        assert!(
7408            shaders.contains("kernel void softmax"),
7409            "shaders should contain softmax kernel"
7410        );
7411        assert!(
7412            shaders.contains("kernel void silu_mul("),
7413            "shaders should contain silu_mul kernel"
7414        );
7415        assert!(
7416            shaders.contains("kernel void silu_mul_fused"),
7417            "shaders should contain silu_mul_fused kernel"
7418        );
7419        assert!(
7420            shaders.contains("kernel void elementwise_add"),
7421            "shaders should contain elementwise_add kernel"
7422        );
7423        assert!(
7424            shaders.contains("kernel void attention"),
7425            "shaders should contain attention kernel"
7426        );
7427        assert!(
7428            shaders.contains("kernel void add_inplace"),
7429            "shaders should contain add_inplace kernel"
7430        );
7431        assert!(
7432            shaders.contains("kernel void copy_buffer"),
7433            "shaders should contain copy_buffer kernel"
7434        );
7435        assert!(
7436            shaders.contains("kernel void copy_offset"),
7437            "shaders should contain copy_offset kernel"
7438        );
7439    }
7440
7441    #[test]
7442    fn generated_shaders_use_simdgroup_features() {
7443        let shaders = generate_metal_shaders(&minimal_config());
7444
7445        assert!(
7446            shaders.contains("threadgroup_barrier"),
7447            "shaders should use threadgroup barriers"
7448        );
7449        assert!(
7450            shaders.contains("threadgroup float"),
7451            "shaders should use threadgroup shared memory"
7452        );
7453        assert!(
7454            shaders.contains("thread_index_in_threadgroup"),
7455            "shaders should use threadgroup indexing"
7456        );
7457        assert!(
7458            shaders.contains("simd_sum"),
7459            "shaders should use simd_sum for warp-level reduction"
7460        );
7461        assert!(
7462            shaders.contains("simd_max"),
7463            "attention kernel should use simd_max for cooperative softmax"
7464        );
7465        assert!(
7466            shaders.contains("thread_index_in_simdgroup"),
7467            "shaders should use simdgroup lane indexing"
7468        );
7469        assert!(
7470            shaders.contains("simdgroup_index_in_threadgroup"),
7471            "shaders should use simdgroup indexing within threadgroup"
7472        );
7473        assert!(
7474            shaders.contains("float4"),
7475            "matmul_vec should use float4 vectorized loads"
7476        );
7477    }
7478
7479    #[test]
7480    fn generated_main_rs_has_tokenizer_usage() {
7481        let config = minimal_config();
7482        let main_rs = generate_main_rs("test-model", &config).unwrap();
7483
7484        assert!(
7485            main_rs.contains("tokenizers::Tokenizer"),
7486            "main.rs should use tokenizers crate"
7487        );
7488        assert!(
7489            main_rs.contains("MetalModel::new"),
7490            "main.rs should call MetalModel::new"
7491        );
7492        assert!(
7493            main_rs.contains("model.forward"),
7494            "main.rs should call model.forward"
7495        );
7496        assert!(
7497            main_rs.contains("memmap2"),
7498            "main.rs should use memmap2 for zero-copy weight loading"
7499        );
7500    }
7501
7502    #[test]
7503    fn missing_config_returns_error() {
7504        let dir = tempfile::tempdir().unwrap();
7505        let graph = Graph::new("no-config");
7506        let result = generate_metal_project(&graph, dir.path(), "fail");
7507        assert!(
7508            matches!(result, Err(MetalCodegenError::MissingConfig)),
7509            "should fail with MissingConfig when graph has no config"
7510        );
7511    }
7512
7513    #[test]
7514    fn sanitize_name_works() {
7515        assert_eq!(sanitize_name("My Model!"), "my-model");
7516        assert_eq!(sanitize_name("test_model"), "test-model");
7517        assert_eq!(sanitize_name("simple"), "simple");
7518    }
7519
7520    #[test]
7521    fn generated_forward_uses_single_command_buffer() {
7522        let config = minimal_config();
7523        let model_rs = generate_model_rs(&config).unwrap();
7524
7525        // The forward function should create exactly one command buffer.
7526        // Use the exact signature to avoid matching forward_prefill/forward_profile.
7527        let forward_start = model_rs
7528            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7529            .unwrap();
7530        let forward_body = &model_rs[forward_start..];
7531        // End at the next pub/private method
7532        let forward_end = forward_body
7533            .find("\n    pub fn forward_profile")
7534            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7535            .or_else(|| forward_body.find("\n    fn dispatch_"))
7536            .unwrap_or(forward_body.len());
7537        let forward_code = &forward_body[..forward_end];
7538
7539        // Should have exactly one new_command_buffer call
7540        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7541        assert_eq!(
7542            cmd_buf_count, 1,
7543            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7544        );
7545
7546        // Should have exactly one commit call
7547        let commit_count = forward_code.matches("cmd.commit()").count();
7548        assert_eq!(
7549            commit_count, 1,
7550            "forward() should commit exactly once, found {commit_count}"
7551        );
7552
7553        // Should wait: once for cmd + possibly once for prev_cmd drain
7554        let wait_count = forward_code.matches("wait_until_completed()").count();
7555        assert!(
7556            wait_count >= 1 && wait_count <= 2,
7557            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7558        );
7559    }
7560
7561    #[test]
7562    fn generated_model_has_preallocated_working_buffers() {
7563        let config = minimal_config();
7564        let model_rs = generate_model_rs(&config).unwrap();
7565
7566        for buf_name in &[
7567            "normed_buf",
7568            "qkv_buf",
7569            "attn_out_buf",
7570            "attn_proj_buf",
7571            "gate_up_buf",
7572            "ffn_hidden_buf",
7573            "ffn_out_buf",
7574            "add_tmp_buf",
7575        ] {
7576            assert!(
7577                model_rs.contains(&format!("{buf_name}: Buffer")),
7578                "MetalModel should have pre-allocated {buf_name} field"
7579            );
7580        }
7581    }
7582
7583    #[test]
7584    fn generated_dispatch_helpers_take_compute_encoder_ref() {
7585        let config = minimal_config();
7586        let model_rs = generate_model_rs(&config).unwrap();
7587
7588        for method in &[
7589            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7590            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7591            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7592            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7593            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7594            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7595            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7596            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7597            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7598            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7599            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7600            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7601            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7602        ] {
7603            assert!(
7604                model_rs.contains(method),
7605                "model.rs should contain dispatch helper: {method}"
7606            );
7607        }
7608    }
7609
7610    #[test]
7611    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7612        let config = minimal_config();
7613        let model_rs = generate_model_rs(&config).unwrap();
7614
7615        // Find dispatch helpers section and check none create their own encoders
7616        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7617        let helpers_code = &model_rs[helpers_start..];
7618
7619        // None of the dispatch_ helpers should call new_command_buffer
7620        assert!(
7621            !helpers_code.contains("self.queue.new_command_buffer()"),
7622            "dispatch helpers should not create their own command buffers"
7623        );
7624
7625        // None should create their own compute encoders
7626        assert!(
7627            !helpers_code.contains("new_compute_command_encoder()"),
7628            "dispatch helpers should not create their own compute encoders"
7629        );
7630
7631        // None should call end_encoding
7632        assert!(
7633            !helpers_code.contains("end_encoding()"),
7634            "dispatch helpers should not call end_encoding"
7635        );
7636
7637        // None should call commit or wait
7638        assert!(
7639            !helpers_code.contains(".commit()"),
7640            "dispatch helpers should not commit command buffers"
7641        );
7642        assert!(
7643            !helpers_code.contains("wait_until_completed"),
7644            "dispatch helpers should not wait on command buffers"
7645        );
7646    }
7647
7648    #[test]
7649    fn generated_forward_batches_compute_encoders() {
7650        let config = minimal_config();
7651        let model_rs = generate_model_rs(&config).unwrap();
7652
7653        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7654        let forward_start = model_rs
7655            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7656            .unwrap();
7657        let forward_body = &model_rs[forward_start..];
7658        let forward_end = forward_body
7659            .find("\n    pub fn forward_profile")
7660            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7661            .or_else(|| forward_body.find("\n    fn dispatch_"))
7662            .unwrap_or(forward_body.len());
7663        let forward_code = &forward_body[..forward_end];
7664
7665        // Forward should not allocate new buffers
7666        assert!(
7667            !forward_code.contains("device.new_buffer"),
7668            "forward() should not allocate new buffers per call"
7669        );
7670
7671        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7672        // Copy operations use compute copy kernels instead of blit encoders.
7673        let compute_encoder_count = forward_code
7674            .matches("new_compute_command_encoder()")
7675            .count();
7676        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7677
7678        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7679        assert_eq!(
7680            compute_encoder_count, 1,
7681            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7682        );
7683        assert_eq!(
7684            blit_encoder_count, 0,
7685            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7686        );
7687    }
7688
7689    #[test]
7690    fn generated_forward_uses_add_inplace() {
7691        let config = minimal_config();
7692        let model_rs = generate_model_rs(&config).unwrap();
7693
7694        // Should use in-place add (no blit copy-back needed)
7695        assert!(
7696            model_rs.contains("dispatch_add_inplace"),
7697            "forward() should use dispatch_add_inplace for residual connections"
7698        );
7699        assert!(
7700            model_rs.contains("add_inplace_pipeline"),
7701            "MetalModel should have add_inplace_pipeline"
7702        );
7703    }
7704
7705    fn minimal_q8_config() -> ModelConfig {
7706        ModelConfig {
7707            architecture: Architecture::Llama,
7708            hidden_size: 64,
7709            intermediate_size: 128,
7710            num_layers: 2,
7711            num_attention_heads: 4,
7712            num_kv_heads: 4,
7713            head_dim: 16,
7714            vocab_size: 256,
7715            max_seq_len: 512,
7716            rms_norm_eps: 1e-5,
7717            rope_theta: 10000.0,
7718            dtype: DType::Q8_0,
7719            sliding_window_size: None,
7720            qkv_bias: false,
7721            hidden_activation: HiddenActivation::SiLU,
7722        }
7723    }
7724
7725    #[test]
7726    fn generated_shaders_contain_q8_kernel() {
7727        let shaders = generate_metal_shaders(&minimal_config());
7728
7729        assert!(
7730            shaders.contains("kernel void matmul_vec_q8"),
7731            "shaders should contain matmul_vec_q8 kernel"
7732        );
7733        assert!(
7734            shaders.contains("device const uchar* matrix"),
7735            "matmul_vec_q8 should accept raw Q8_0 bytes"
7736        );
7737        assert!(
7738            shaders.contains("packed_short4"),
7739            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7740        );
7741        assert!(
7742            shaders.contains("as_type<char2>"),
7743            "matmul_vec_q8 should bitcast short lanes to char2"
7744        );
7745        assert!(
7746            shaders.contains("device const half*"),
7747            "matmul_vec_q8 should read f16 scale via half pointer"
7748        );
7749    }
7750
7751    #[test]
7752    fn generated_model_uses_fused_qkv_projections() {
7753        let config = minimal_config();
7754        let model_rs = generate_model_rs(&config).unwrap();
7755
7756        // Should have fused QKV weight in layer buffers
7757        assert!(
7758            model_rs.contains("qkv_weight: Buffer"),
7759            "LayerBuffers should have fused qkv_weight field"
7760        );
7761        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7762        assert!(
7763            !model_rs.contains("    q_weight: Buffer"),
7764            "LayerBuffers should not have separate q_weight field"
7765        );
7766        assert!(
7767            !model_rs.contains("    k_weight: Buffer"),
7768            "LayerBuffers should not have separate k_weight field"
7769        );
7770        assert!(
7771            !model_rs.contains("    v_weight: Buffer"),
7772            "LayerBuffers should not have separate v_weight field"
7773        );
7774
7775        // Should have fused gate_up_weight
7776        assert!(
7777            model_rs.contains("gate_up_weight: Buffer"),
7778            "LayerBuffers should have fused gate_up_weight field"
7779        );
7780        // Should NOT have separate gate/up weight fields
7781        assert!(
7782            !model_rs.contains("    gate_weight: Buffer"),
7783            "LayerBuffers should not have separate gate_weight field"
7784        );
7785        assert!(
7786            !model_rs.contains("    up_weight: Buffer"),
7787            "LayerBuffers should not have separate up_weight field"
7788        );
7789
7790        // Should have fused working buffers
7791        assert!(
7792            model_rs.contains("qkv_buf: Buffer"),
7793            "MetalModel should have fused qkv_buf"
7794        );
7795        assert!(
7796            model_rs.contains("gate_up_buf: Buffer"),
7797            "MetalModel should have fused gate_up_buf"
7798        );
7799
7800        // Forward pass should use fused dispatch
7801        assert!(
7802            model_rs.contains("dispatch_silu_mul_fused"),
7803            "forward pass should use dispatch_silu_mul_fused"
7804        );
7805        assert!(
7806            model_rs.contains("dispatch_rope_offset"),
7807            "forward pass should use dispatch_rope_offset for fused QKV"
7808        );
7809        assert!(
7810            model_rs.contains("dispatch_attention_offset"),
7811            "forward pass should use dispatch_attention_offset for fused QKV"
7812        );
7813    }
7814
7815    #[test]
7816    fn q8_model_has_matmul_q8_pipeline() {
7817        let config = minimal_q8_config();
7818        let model_rs = generate_model_rs(&config).unwrap();
7819
7820        assert!(
7821            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7822            "MetalModel should have matmul_q8_pipeline field"
7823        );
7824        assert!(
7825            model_rs.contains("matmul_q8_pipeline,"),
7826            "MetalModel Self should include matmul_q8_pipeline"
7827        );
7828    }
7829
7830    #[test]
7831    fn q8_model_uses_dispatch_matmul_q8() {
7832        let config = minimal_q8_config();
7833        let model_rs = generate_model_rs(&config).unwrap();
7834
7835        assert!(
7836            model_rs.contains("dispatch_matmul_q8"),
7837            "Q8_0 model should use dispatch_matmul_q8 for projections"
7838        );
7839        assert!(
7840            model_rs.contains("fn dispatch_matmul_q8"),
7841            "model.rs should define dispatch_matmul_q8 method"
7842        );
7843    }
7844
7845    #[test]
7846    fn q8_model_loads_raw_bytes_not_dequantized() {
7847        let config = minimal_q8_config();
7848        let model_rs = generate_model_rs(&config).unwrap();
7849
7850        // Should NOT contain dequantization code
7851        assert!(
7852            !model_rs.contains("f16_to_f32"),
7853            "Q8_0 model should not dequantize weights to f32"
7854        );
7855        assert!(
7856            !model_rs.contains("f32_data"),
7857            "Q8_0 model should not create f32 weight data"
7858        );
7859
7860        // Should load raw Q8_0 bytes directly
7861        assert!(
7862            model_rs.contains("total_raw as u64"),
7863            "Q8_0 model should load raw bytes into Metal buffer"
7864        );
7865    }
7866
7867    #[test]
7868    fn q8_model_norms_stay_f32() {
7869        let config = minimal_q8_config();
7870        let model_rs = generate_model_rs(&config).unwrap();
7871
7872        // Norm weights should still use f32 buffers
7873        assert!(
7874            model_rs.contains("let attn_norm = next_f32_buffer"),
7875            "attn_norm should use f32 buffer even for Q8_0 models"
7876        );
7877        assert!(
7878            model_rs.contains("let ffn_norm = next_f32_buffer"),
7879            "ffn_norm should use f32 buffer even for Q8_0 models"
7880        );
7881        assert!(
7882            model_rs.contains("let norm_buf = next_f32_buffer"),
7883            "final norm should use f32 buffer even for Q8_0 models"
7884        );
7885    }
7886
7887    #[test]
7888    fn q8_model_uses_fused_weight_loading() {
7889        let config = minimal_q8_config();
7890        let model_rs = generate_model_rs(&config).unwrap();
7891
7892        // Should use fused Q8 buffer loading for QKV
7893        assert!(
7894            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7895            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7896        );
7897        // Should use fused Q8 buffer loading for gate+up
7898        assert!(
7899            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7900            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7901        );
7902        // Should still use regular q8 buffer for individual weights
7903        assert!(
7904            model_rs.contains("let o_weight = next_q8_buffer"),
7905            "Q8_0 model should use next_q8_buffer for O weight"
7906        );
7907        assert!(
7908            model_rs.contains("let down_weight = next_q8_buffer"),
7909            "Q8_0 model should use next_q8_buffer for down weight"
7910        );
7911    }
7912
7913    #[test]
7914    fn f32_model_does_not_use_q8_dispatch() {
7915        let config = minimal_config();
7916        let model_rs = generate_model_rs(&config).unwrap();
7917
7918        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7919        let forward_start = model_rs
7920            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7921            .unwrap();
7922        let forward_body = &model_rs[forward_start..];
7923        let forward_end = forward_body
7924            .find("\n    fn dispatch_")
7925            .unwrap_or(forward_body.len());
7926        let forward_code = &forward_body[..forward_end];
7927
7928        assert!(
7929            !forward_code.contains("dispatch_matmul_q8"),
7930            "f32 model forward should not use dispatch_matmul_q8"
7931        );
7932    }
7933
7934    #[test]
7935    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7936        let config = minimal_q8_config();
7937        let model_rs = generate_model_rs(&config).unwrap();
7938
7939        assert!(
7940            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7941            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7942        );
7943    }
7944
7945    #[test]
7946    fn generated_model_has_double_buffered_prefill() {
7947        let config = minimal_config();
7948        let model_rs = generate_model_rs(&config).unwrap();
7949
7950        // MetalModel should have prev_cmd field for double-buffered prefill
7951        assert!(
7952            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7953            "MetalModel should have prev_cmd field for double-buffered prefill"
7954        );
7955
7956        // Should have forward_prefill method
7957        assert!(
7958            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7959            "MetalModel should have forward_prefill method"
7960        );
7961
7962        // forward() should drain prev_cmd at the start
7963        assert!(
7964            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7965            "forward() should drain prev_cmd from previous prefill"
7966        );
7967    }
7968
7969    #[test]
7970    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7971        let config = minimal_config();
7972        let main_rs = generate_main_rs("test-model", &config).unwrap();
7973
7974        assert!(
7975            main_rs.contains("forward_prefill"),
7976            "main.rs should use forward_prefill for intermediate prompt tokens"
7977        );
7978        assert!(
7979            main_rs.contains("double-buffered"),
7980            "main.rs should document double-buffered prefill"
7981        );
7982    }
7983
7984    #[test]
7985    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7986        let shaders = generate_metal_shaders(&minimal_config());
7987
7988        assert!(
7989            shaders.contains("packed_short4"),
7990            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7991        );
7992        assert!(
7993            shaders.contains("d0[0]"),
7994            "matmul_vec_q8 should index the wide pointer for row 0"
7995        );
7996        assert!(
7997            shaders.contains("as_type<char2>"),
7998            "matmul_vec_q8 should bitcast short lanes to char2"
7999        );
8000        assert!(
8001            shaders.contains("dot("),
8002            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
8003        );
8004    }
8005
8006    // ── Q4_0 tests ──────────────────────────────────────────────────────
8007
8008    fn minimal_q4_config() -> ModelConfig {
8009        ModelConfig {
8010            architecture: Architecture::Llama,
8011            hidden_size: 64,
8012            intermediate_size: 128,
8013            num_layers: 2,
8014            num_attention_heads: 4,
8015            num_kv_heads: 4,
8016            head_dim: 16,
8017            vocab_size: 256,
8018            max_seq_len: 512,
8019            rms_norm_eps: 1e-5,
8020            rope_theta: 10000.0,
8021            dtype: DType::Q4_0,
8022            sliding_window_size: None,
8023            qkv_bias: false,
8024            hidden_activation: HiddenActivation::SiLU,
8025        }
8026    }
8027
8028    #[test]
8029    fn generated_shaders_contain_q4_kernel() {
8030        let shaders = generate_metal_shaders(&minimal_config());
8031
8032        assert!(
8033            shaders.contains("kernel void matmul_vec_q4"),
8034            "shaders should contain matmul_vec_q4 kernel"
8035        );
8036        assert!(
8037            shaders.contains("Q4_ROWS_PER_TG"),
8038            "shaders should define Q4_ROWS_PER_TG constant"
8039        );
8040        assert!(
8041            shaders.contains("Q4_ROWS_PER_SG"),
8042            "shaders should define Q4_ROWS_PER_SG constant"
8043        );
8044    }
8045
8046    #[test]
8047    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
8048        let shaders = generate_metal_shaders(&minimal_config());
8049
8050        // Q4_0 kernel should use uchar4 for packed byte loads
8051        assert!(
8052            shaders.contains("uchar4"),
8053            "matmul_vec_q4 should use uchar4 for packed byte loads"
8054        );
8055        // Should unpack nibbles with &0xF and >>4
8056        assert!(
8057            shaders.contains("&0xF"),
8058            "matmul_vec_q4 should extract low nibble with &0xF"
8059        );
8060        assert!(
8061            shaders.contains(">>4"),
8062            "matmul_vec_q4 should extract high nibble with >>4"
8063        );
8064        // Should subtract 8 to convert unsigned to signed
8065        assert!(
8066            shaders.contains("-8)"),
8067            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
8068        );
8069        // Should use 18-byte block size
8070        assert!(
8071            shaders.contains("blk * 18"),
8072            "matmul_vec_q4 should use 18-byte block stride"
8073        );
8074    }
8075
8076    #[test]
8077    fn q4_model_has_matmul_q4_pipeline() {
8078        let config = minimal_q4_config();
8079        let model_rs = generate_model_rs(&config).unwrap();
8080
8081        assert!(
8082            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
8083            "MetalModel should have matmul_q4_pipeline field"
8084        );
8085        assert!(
8086            model_rs.contains("matmul_q4_pipeline,"),
8087            "MetalModel Self should include matmul_q4_pipeline"
8088        );
8089    }
8090
8091    #[test]
8092    fn q4_model_uses_dispatch_matmul_q4() {
8093        let config = minimal_q4_config();
8094        let model_rs = generate_model_rs(&config).unwrap();
8095
8096        assert!(
8097            model_rs.contains("dispatch_matmul_q4"),
8098            "Q4_0 model should use dispatch_matmul_q4 for projections"
8099        );
8100        assert!(
8101            model_rs.contains("fn dispatch_matmul_q4"),
8102            "model.rs should define dispatch_matmul_q4 method"
8103        );
8104    }
8105
8106    #[test]
8107    fn q4_model_loads_raw_bytes_not_dequantized() {
8108        let config = minimal_q4_config();
8109        let model_rs = generate_model_rs(&config).unwrap();
8110
8111        // Should NOT contain dequantization code
8112        assert!(
8113            !model_rs.contains("f16_to_f32"),
8114            "Q4_0 model should not dequantize weights to f32"
8115        );
8116
8117        // Should load raw Q4_0 bytes directly
8118        assert!(
8119            model_rs.contains("total_raw as u64"),
8120            "Q4_0 model should load raw bytes into Metal buffer"
8121        );
8122    }
8123
8124    #[test]
8125    fn q4_model_norms_stay_f32() {
8126        let config = minimal_q4_config();
8127        let model_rs = generate_model_rs(&config).unwrap();
8128
8129        assert!(
8130            model_rs.contains("let attn_norm = next_f32_buffer"),
8131            "attn_norm should use f32 buffer even for Q4_0 models"
8132        );
8133        assert!(
8134            model_rs.contains("let ffn_norm = next_f32_buffer"),
8135            "ffn_norm should use f32 buffer even for Q4_0 models"
8136        );
8137        assert!(
8138            model_rs.contains("let norm_buf = next_f32_buffer"),
8139            "final norm should use f32 buffer even for Q4_0 models"
8140        );
8141    }
8142
8143    #[test]
8144    fn q4_model_uses_fused_weight_loading() {
8145        let config = minimal_q4_config();
8146        let model_rs = generate_model_rs(&config).unwrap();
8147
8148        assert!(
8149            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
8150            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
8151        );
8152        assert!(
8153            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
8154            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
8155        );
8156        assert!(
8157            model_rs.contains("let o_weight = next_q4_buffer"),
8158            "Q4_0 model should use next_q4_buffer for O weight"
8159        );
8160        assert!(
8161            model_rs.contains("let down_weight = next_q4_buffer"),
8162            "Q4_0 model should use next_q4_buffer for down weight"
8163        );
8164    }
8165
8166    #[test]
8167    fn attention_flash_batch_kernel_exists() {
8168        // The flash kernel is still wired into the library (pipeline, kernel
8169        // source).  Dispatch is currently routed to the legacy path pending a
8170        // fix for a numerical issue discovered after the prompt-chunking fix.
8171        let config = minimal_config();
8172        let model_rs = generate_model_rs(&config).unwrap();
8173        let shaders = generate_metal_shaders(&config);
8174
8175        assert!(
8176            shaders.contains("kernel void attention_flash_batch"),
8177            "shaders.metal must still contain the attention_flash_batch kernel"
8178        );
8179        assert!(
8180            shaders.contains("FLASH_K_TILE"),
8181            "flash kernel must tile K/V with a FLASH_K_TILE constant"
8182        );
8183        assert!(
8184            model_rs.contains("attention_flash_batch_pipeline"),
8185            "MetalModel must register the flash attention pipeline"
8186        );
8187    }
8188
8189    #[test]
8190    fn decode_uses_fused_rope_and_kv_copy() {
8191        // v0.7.2: single-token decode reuses `rope_qk_batch` (M=1) and
8192        // `copy_kv_both_batch` (M=1) instead of calling the separate
8193        // rope_offset + copy_from_offset_f16 pairs, saving 2 dispatches +
8194        // barriers per layer.
8195        let config = minimal_config();
8196        let model_rs = generate_model_rs(&config).unwrap();
8197
8198        // forward() and forward_profile() each have a decode loop.
8199        // Locate `pub fn forward(` and `pub fn forward_profile(` and verify
8200        // the fused dispatches appear between the function header and the
8201        // start of forward_prefill (which also uses these fused helpers).
8202        let forward_start = model_rs.find("pub fn forward(").expect("forward missing");
8203        let forward_end = model_rs[forward_start..]
8204            .find("pub fn forward_prefill(")
8205            .expect("forward_prefill missing");
8206        let forward_body = &model_rs[forward_start..forward_start + forward_end];
8207
8208        assert!(
8209            forward_body.contains("dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1"),
8210            "single-token forward must use fused rope_qk_batch with M=1"
8211        );
8212        assert!(
8213            forward_body.contains("dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,"),
8214            "single-token forward must use fused copy_kv_both_batch"
8215        );
8216        assert!(
8217            !forward_body.contains("dispatch_copy_from_offset_f16"),
8218            "decode path should no longer call the per-K/per-V copy"
8219        );
8220    }
8221
8222    #[test]
8223    fn kv_cache_stored_as_f16() {
8224        // v0.7.1: KV cache is f16 (2 bytes/element) instead of f32 to halve
8225        // attention memory bandwidth and KV RAM footprint.  All attention
8226        // kernels must read K/V as `device const half*`; copy kernels must
8227        // write half; the f32->f16 copy kernel must be wired for single-
8228        // token decode KV writes.
8229        let config = minimal_config();
8230        let shaders = generate_metal_shaders(&config);
8231        let model_rs = generate_model_rs(&config).unwrap();
8232
8233        for kernel in [
8234            "kernel void attention(",
8235            "kernel void attention_batch(",
8236            "kernel void attention_flash_batch(",
8237            "kernel void attention_mma_flash_batch(",
8238        ] {
8239            let start = shaders
8240                .find(kernel)
8241                .unwrap_or_else(|| panic!("kernel {kernel} missing"));
8242            // Slice through the signature block — scan for "){" which closes
8243            // the parameter list at the top of every kernel's body.
8244            let sig_end = shaders[start..]
8245                .find("){")
8246                .unwrap_or_else(|| shaders[start..].find(") {").unwrap());
8247            let sig = &shaders[start..start + sig_end];
8248            assert!(
8249                sig.contains("device const half*")
8250                    && sig.contains("k_cache")
8251                    && sig.contains("v_cache"),
8252                "{kernel} must read k_cache/v_cache as `device const half*`"
8253            );
8254            assert!(
8255                !sig.contains("device const float* k_cache")
8256                    && !sig.contains("device const float*  k_cache"),
8257                "{kernel} still reads k_cache as float"
8258            );
8259        }
8260
8261        assert!(
8262            shaders.contains("kernel void copy_f32_to_f16_offset"),
8263            "f32->f16 copy kernel must be present for single-token decode KV writes"
8264        );
8265        assert!(
8266            model_rs.contains("dispatch_copy_from_offset_f16"),
8267            "single-token decode must dispatch the f32->f16 KV copy"
8268        );
8269    }
8270
8271    #[test]
8272    fn decode_attention_uses_half4_vectorized_loads() {
8273        // v0.7.3: vectorized K/V via half4 + float4.
8274        // v0.7.4: V-step restructured — one simdgroup per d4 chunk, 32 lanes
8275        // partition seq_len with simd_sum reduction. Fixes head_dim=64 under-
8276        // utilization (16 productive threads → 256 productive threads).
8277        let config = minimal_config();
8278        let shaders = generate_metal_shaders(&config);
8279
8280        let start = shaders
8281            .find("kernel void attention(")
8282            .expect("decode attention kernel missing");
8283        // Slice just the decode kernel body — stop at the next kernel.
8284        let end_rel = shaders[start + 1..]
8285            .find("kernel void ")
8286            .expect("next kernel missing");
8287        let body = &shaders[start..start + 1 + end_rel];
8288
8289        assert!(
8290            body.contains("device const half4*"),
8291            "decode attention must half4-load K/V"
8292        );
8293        assert!(
8294            body.contains("device const float4*"),
8295            "decode attention must float4-load Q"
8296        );
8297        assert!(
8298            body.contains("device float4*"),
8299            "decode attention must float4-store output"
8300        );
8301        assert!(
8302            body.contains("head_dim4"),
8303            "decode attention must iterate head_dim in chunks of 4"
8304        );
8305        // v0.7.4: V-step uses simd_sum per component to reduce across 32 lanes.
8306        assert!(
8307            body.contains("simd_sum(partial.x)")
8308                && body.contains("simd_sum(partial.y)")
8309                && body.contains("simd_sum(partial.z)")
8310                && body.contains("simd_sum(partial.w)"),
8311            "decode attention V-step must reduce float4 partials via simd_sum"
8312        );
8313        // And the V-step outer loop must iterate d4 across simdgroups (simd_id), not threads (tid).
8314        assert!(
8315            body.contains("for (uint d4 = simd_id; d4 < head_dim4; d4 += 8)"),
8316            "decode attention V-step must partition d4 across simdgroups"
8317        );
8318    }
8319
8320    #[test]
8321    fn attention_mma_flash_batch_kernel_wired() {
8322        // MMA-accelerated flash attention (issue #212).  Default-on in v0.7.0
8323        // when HEAD_DIM ≤ 128 and num_tokens ≥ 8.  FORGE_MMA_ATTN=0 opts out.
8324        let config = minimal_config();
8325        let model_rs = generate_model_rs(&config).unwrap();
8326        let shaders = generate_metal_shaders(&config);
8327
8328        assert!(
8329            shaders.contains("kernel void attention_mma_flash_batch"),
8330            "shaders.metal must contain the MMA flash kernel"
8331        );
8332        assert!(
8333            shaders.contains("FLASH_MMA_Q_BLOCK"),
8334            "MMA flash kernel must define Q_BLOCK tiling constant"
8335        );
8336        assert!(
8337            shaders.contains("simdgroup_multiply_accumulate"),
8338            "MMA flash kernel must use hardware MMA"
8339        );
8340        assert!(
8341            model_rs.contains("attention_mma_flash_batch_pipeline"),
8342            "MetalModel must register the MMA flash pipeline"
8343        );
8344        assert!(
8345            model_rs.contains("mma_opt_out"),
8346            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out"
8347        );
8348        assert!(
8349            model_rs.contains("!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8"),
8350            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8"
8351        );
8352    }
8353
8354    #[test]
8355    fn forward_prefill_batch_chunks_by_max_batch_size() {
8356        // Regression: prior to v0.6.4 forward_prefill_batch truncated prompts
8357        // longer than MAX_BATCH_SIZE (512) tokens via `.min(MAX_BATCH_SIZE)`,
8358        // silently dropping the middle of long prompts.  Must now loop over
8359        // MAX_BATCH_SIZE-sized chunks and carry KV-cache state across them.
8360        let config = minimal_config();
8361        let model_rs = generate_model_rs(&config).unwrap();
8362        assert!(
8363            model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
8364            "forward_prefill_batch must chunk long prompts"
8365        );
8366        assert!(
8367            !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
8368            "the old truncation path must be gone"
8369        );
8370    }
8371
8372    #[test]
8373    fn qwen2_qkv_bias_wired_through_metal_codegen() {
8374        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
8375        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
8376        // loader, pipeline, and dispatch call in the expected places.
8377        let config = ModelConfig {
8378            architecture: Architecture::Qwen2,
8379            qkv_bias: true,
8380            ..minimal_config()
8381        };
8382        let model_rs = generate_model_rs(&config).unwrap();
8383
8384        assert!(
8385            model_rs.contains("qkv_bias: Buffer"),
8386            "Qwen2 LayerBuffers must declare qkv_bias field"
8387        );
8388        assert!(
8389            model_rs.contains("let qkv_bias = next_f32_buffer"),
8390            "Qwen2 layer init must load the bias from the weight blob"
8391        );
8392        assert!(
8393            model_rs.contains("add_bias_batch_pipeline"),
8394            "Qwen2 model struct must include the add_bias_batch_pipeline"
8395        );
8396        assert!(
8397            model_rs.contains("fn dispatch_add_bias_batch"),
8398            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
8399        );
8400        assert!(
8401            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
8402            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
8403        );
8404        assert!(
8405            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
8406            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
8407        );
8408
8409        // The add_bias_batch MSL kernel must be in the shader source.
8410        let shaders = generate_metal_shaders(&config);
8411        assert!(
8412            shaders.contains("kernel void add_bias_batch"),
8413            "shaders.metal must contain the add_bias_batch kernel"
8414        );
8415    }
8416
8417    #[test]
8418    fn llama_does_not_emit_qkv_bias_machinery() {
8419        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
8420        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
8421        let config = minimal_config();
8422        assert!(!config.qkv_bias);
8423        let model_rs = generate_model_rs(&config).unwrap();
8424        assert!(
8425            !model_rs.contains("qkv_bias: Buffer"),
8426            "Llama must not have qkv_bias field"
8427        );
8428        assert!(
8429            !model_rs.contains("add_bias_batch_pipeline"),
8430            "Llama must not pull in add_bias_batch_pipeline"
8431        );
8432        assert!(
8433            !model_rs.contains("dispatch_add_bias_batch"),
8434            "Llama must not call dispatch_add_bias_batch"
8435        );
8436    }
8437
8438    #[test]
8439    fn q4_dispatch_helper_takes_compute_encoder_ref() {
8440        let config = minimal_q4_config();
8441        let model_rs = generate_model_rs(&config).unwrap();
8442
8443        assert!(
8444            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
8445            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
8446        );
8447    }
8448
8449    #[test]
8450    fn f32_model_does_not_use_q4_dispatch() {
8451        let config = minimal_config();
8452        let model_rs = generate_model_rs(&config).unwrap();
8453
8454        let forward_start = model_rs
8455            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8456            .unwrap();
8457        let forward_body = &model_rs[forward_start..];
8458        let forward_end = forward_body
8459            .find("\n    fn dispatch_")
8460            .unwrap_or(forward_body.len());
8461        let forward_code = &forward_body[..forward_end];
8462
8463        assert!(
8464            !forward_code.contains("dispatch_matmul_q4"),
8465            "f32 model forward should not use dispatch_matmul_q4"
8466        );
8467    }
8468
8469    #[test]
8470    fn q4_model_lm_head_uses_q4_buffer() {
8471        let config = minimal_q4_config();
8472        let model_rs = generate_model_rs(&config).unwrap();
8473
8474        assert!(
8475            model_rs.contains("let lm_head_buf = next_q4_buffer"),
8476            "Q4_0 model should use next_q4_buffer for lm_head"
8477        );
8478    }
8479
8480    #[test]
8481    fn vec_tile_size_matches_model_dimensions() {
8482        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
8483        let small = minimal_config();
8484        let shaders_small = generate_metal_shaders(&small);
8485        assert!(
8486            shaders_small.contains("vec_tile[128]"),
8487            "vec_tile should be sized to max(hidden, intermediate) = 128"
8488        );
8489
8490        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
8491        let mut large = minimal_config();
8492        large.hidden_size = 2048;
8493        large.intermediate_size = 8192;
8494        let shaders_large = generate_metal_shaders(&large);
8495        assert!(
8496            shaders_large.contains("vec_tile[8192]"),
8497            "vec_tile should be 8192 for models with intermediate=8192"
8498        );
8499        assert!(
8500            !shaders_large.contains("vec_tile[4096]"),
8501            "vec_tile should NOT be hardcoded to 4096"
8502        );
8503    }
8504
8505    #[test]
8506    fn generated_cargo_toml_has_server_deps() {
8507        let toml = generate_cargo_toml("my-model");
8508        assert!(
8509            toml.contains("tiny_http"),
8510            "Cargo.toml should depend on tiny_http"
8511        );
8512        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
8513        assert!(
8514            toml.contains("serde_json"),
8515            "Cargo.toml should depend on serde_json"
8516        );
8517    }
8518
8519    #[test]
8520    fn generated_main_rs_has_serve_mode() {
8521        let config = minimal_config();
8522        let main_rs = generate_main_rs("test-model", &config).unwrap();
8523
8524        assert!(
8525            main_rs.contains("--serve"),
8526            "main.rs should parse --serve flag"
8527        );
8528        assert!(
8529            main_rs.contains("--port"),
8530            "main.rs should parse --port flag"
8531        );
8532        assert!(
8533            main_rs.contains("fn serve("),
8534            "main.rs should define serve function"
8535        );
8536        assert!(
8537            main_rs.contains("tiny_http::Server::http"),
8538            "main.rs should create tiny_http server"
8539        );
8540    }
8541
8542    #[test]
8543    fn generated_main_rs_has_chat_completions_endpoint() {
8544        let config = minimal_config();
8545        let main_rs = generate_main_rs("test-model", &config).unwrap();
8546
8547        assert!(
8548            main_rs.contains("/v1/chat/completions"),
8549            "main.rs should handle /v1/chat/completions endpoint"
8550        );
8551        assert!(
8552            main_rs.contains("/v1/models"),
8553            "main.rs should handle /v1/models endpoint"
8554        );
8555        assert!(
8556            main_rs.contains("/health"),
8557            "main.rs should handle /health endpoint"
8558        );
8559    }
8560
8561    #[test]
8562    fn generated_main_rs_has_sse_streaming() {
8563        let config = minimal_config();
8564        let main_rs = generate_main_rs("test-model", &config).unwrap();
8565
8566        assert!(
8567            main_rs.contains("text/event-stream"),
8568            "main.rs should set SSE content type for streaming"
8569        );
8570        assert!(
8571            main_rs.contains("chat.completion.chunk"),
8572            "main.rs should emit SSE chunks"
8573        );
8574        assert!(
8575            main_rs.contains("[DONE]"),
8576            "main.rs should emit [DONE] sentinel"
8577        );
8578    }
8579
8580    #[test]
8581    fn generated_main_rs_has_chat_message_formatting() {
8582        let config = minimal_config();
8583        let main_rs = generate_main_rs("test-model", &config).unwrap();
8584
8585        assert!(
8586            main_rs.contains("fn format_chat_messages"),
8587            "main.rs should define format_chat_messages function"
8588        );
8589        assert!(
8590            main_rs.contains("<|im_start|>"),
8591            "main.rs should use ChatML format"
8592        );
8593        assert!(
8594            main_rs.contains("<|im_end|>"),
8595            "main.rs should use ChatML format"
8596        );
8597    }
8598
8599    #[test]
8600    fn generated_main_rs_has_request_types() {
8601        let config = minimal_config();
8602        let main_rs = generate_main_rs("test-model", &config).unwrap();
8603
8604        assert!(
8605            main_rs.contains("struct ChatRequest"),
8606            "main.rs should define ChatRequest struct"
8607        );
8608        assert!(
8609            main_rs.contains("struct ChatMessage"),
8610            "main.rs should define ChatMessage struct"
8611        );
8612        assert!(
8613            main_rs.contains("Deserialize"),
8614            "main.rs should derive Deserialize for request types"
8615        );
8616    }
8617
8618    #[test]
8619    fn generated_model_has_reset_method() {
8620        let config = minimal_config();
8621        let model_rs = generate_model_rs(&config).unwrap();
8622
8623        assert!(
8624            model_rs.contains("pub fn reset(&mut self)"),
8625            "model.rs should have a reset() method for multi-request serving"
8626        );
8627        assert!(
8628            model_rs.contains("self.pos = 0"),
8629            "reset() should reset position to 0"
8630        );
8631    }
8632
8633    #[test]
8634    fn generated_main_rs_cli_mode_still_works() {
8635        let config = minimal_config();
8636        let main_rs = generate_main_rs("test-model", &config).unwrap();
8637
8638        // CLI mode should still be functional
8639        assert!(
8640            main_rs.contains("fn cli_mode("),
8641            "main.rs should define cli_mode function"
8642        );
8643        assert!(
8644            main_rs.contains("model.forward"),
8645            "main.rs should call model.forward"
8646        );
8647        assert!(
8648            main_rs.contains("model.forward_prefill"),
8649            "main.rs should call model.forward_prefill"
8650        );
8651    }
8652
8653    // ── Batched prefill tests ──────────────────────────────────────────
8654
8655    #[test]
8656    fn generated_shaders_contain_batch_kernels() {
8657        let shaders = generate_metal_shaders(&minimal_config());
8658
8659        assert!(
8660            shaders.contains("kernel void matmul_vec_batch"),
8661            "shaders should contain matmul_vec_batch kernel"
8662        );
8663        assert!(
8664            shaders.contains("kernel void matmul_vec_q8_batch"),
8665            "shaders should contain matmul_vec_q8_batch kernel"
8666        );
8667        assert!(
8668            shaders.contains("kernel void matmul_q8_gemm_batch"),
8669            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
8670        );
8671        assert!(
8672            shaders.contains("kernel void matmul_vec_q4_batch"),
8673            "shaders should contain matmul_vec_q4_batch kernel"
8674        );
8675        assert!(
8676            shaders.contains("kernel void rms_norm_batch"),
8677            "shaders should contain rms_norm_batch kernel"
8678        );
8679        assert!(
8680            shaders.contains("kernel void silu_mul_fused_batch"),
8681            "shaders should contain silu_mul_fused_batch kernel"
8682        );
8683        assert!(
8684            shaders.contains("kernel void add_inplace_batch"),
8685            "shaders should contain add_inplace_batch kernel"
8686        );
8687        assert!(
8688            shaders.contains("kernel void copy_embedding_batch"),
8689            "shaders should contain copy_embedding_batch kernel"
8690        );
8691    }
8692
8693    #[test]
8694    fn generated_model_has_batch_pipelines() {
8695        let config = minimal_config();
8696        let model_rs = generate_model_rs(&config).unwrap();
8697
8698        for pipeline in &[
8699            "matmul_batch_pipeline",
8700            "matmul_q8_batch_pipeline",
8701            "matmul_q8_gemm_batch_pipeline",
8702            "matmul_q4_batch_pipeline",
8703            "rms_norm_batch_pipeline",
8704            "rope_batch_pipeline",
8705            "silu_mul_fused_batch_pipeline",
8706            "add_inplace_batch_pipeline",
8707            "copy_embedding_batch_pipeline",
8708            "attention_batch_pipeline",
8709            "copy_kv_batch_pipeline",
8710        ] {
8711            assert!(
8712                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
8713                "MetalModel should have {pipeline} field"
8714            );
8715        }
8716    }
8717
8718    #[test]
8719    fn generated_model_has_batch_buffers() {
8720        let config = minimal_config();
8721        let model_rs = generate_model_rs(&config).unwrap();
8722
8723        for buf in &[
8724            "batch_hidden_buf",
8725            "batch_residual_buf",
8726            "batch_qkv_buf",
8727            "batch_attn_out_buf",
8728            "batch_attn_proj_buf",
8729            "batch_gate_up_buf",
8730            "batch_ffn_hidden_buf",
8731            "batch_ffn_out_buf",
8732            "batch_tokens_buf",
8733            "batch_positions_buf",
8734        ] {
8735            assert!(
8736                model_rs.contains(&format!("{buf}: Buffer")),
8737                "MetalModel should have {buf} field"
8738            );
8739        }
8740    }
8741
8742    #[test]
8743    fn generated_model_has_forward_prefill_batch() {
8744        let config = minimal_config();
8745        let model_rs = generate_model_rs(&config).unwrap();
8746
8747        assert!(
8748            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8749            "MetalModel should have forward_prefill_batch method"
8750        );
8751
8752        // forward_prefill should delegate to forward_prefill_batch
8753        assert!(
8754            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8755            "forward_prefill should delegate to forward_prefill_batch"
8756        );
8757    }
8758
8759    #[test]
8760    fn generated_model_has_max_batch_size_constant() {
8761        let config = minimal_config();
8762        let model_rs = generate_model_rs(&config).unwrap();
8763
8764        assert!(
8765            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8766            "model.rs should define MAX_BATCH_SIZE constant"
8767        );
8768    }
8769
8770    #[test]
8771    fn forward_prefill_batch_uses_batch_dispatch() {
8772        let config = minimal_config();
8773        let model_rs = generate_model_rs(&config).unwrap();
8774
8775        let batch_start = model_rs
8776            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8777            .unwrap();
8778        let batch_body = &model_rs[batch_start..];
8779        let batch_end = batch_body
8780            .find("\n    pub fn reset")
8781            .unwrap_or(batch_body.len());
8782        let batch_code = &batch_body[..batch_end];
8783
8784        // Should use batched dispatch methods
8785        assert!(
8786            batch_code.contains("dispatch_rms_norm_batch"),
8787            "forward_prefill_batch should use dispatch_rms_norm_batch"
8788        );
8789        assert!(
8790            batch_code.contains("dispatch_copy_embedding_batch"),
8791            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8792        );
8793        assert!(
8794            batch_code.contains("dispatch_silu_mul_fused_batch"),
8795            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8796        );
8797        // Should use batched causal attention dispatch
8798        assert!(
8799            batch_code.contains("dispatch_attention_batch"),
8800            "forward_prefill_batch should use dispatch_attention_batch"
8801        );
8802        // Should use fused KV cache copy (both K and V in one dispatch)
8803        assert!(
8804            batch_code.contains("dispatch_copy_kv_both_batch"),
8805            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8806        );
8807        // Should use fused RoPE Q+K dispatch
8808        assert!(
8809            batch_code.contains("dispatch_rope_qk_batch"),
8810            "forward_prefill_batch should use dispatch_rope_qk_batch"
8811        );
8812    }
8813
8814    #[test]
8815    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8816        let config = minimal_q8_config();
8817        let model_rs = generate_model_rs(&config).unwrap();
8818
8819        let batch_start = model_rs
8820            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8821            .unwrap();
8822        let batch_body = &model_rs[batch_start..];
8823        let batch_end = batch_body
8824            .find("\n    pub fn reset")
8825            .unwrap_or(batch_body.len());
8826        let batch_code = &batch_body[..batch_end];
8827
8828        assert!(
8829            batch_code.contains("dispatch_matmul_q8_batch"),
8830            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8831        );
8832    }
8833
8834    #[test]
8835    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8836        let config = minimal_q4_config();
8837        let model_rs = generate_model_rs(&config).unwrap();
8838
8839        let batch_start = model_rs
8840            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8841            .unwrap();
8842        let batch_body = &model_rs[batch_start..];
8843        let batch_end = batch_body
8844            .find("\n    pub fn reset")
8845            .unwrap_or(batch_body.len());
8846        let batch_code = &batch_body[..batch_end];
8847
8848        assert!(
8849            batch_code.contains("dispatch_matmul_q4_batch"),
8850            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8851        );
8852    }
8853
8854    #[test]
8855    fn generated_main_rs_uses_batched_prefill() {
8856        let config = minimal_config();
8857        let main_rs = generate_main_rs("test-model", &config).unwrap();
8858
8859        assert!(
8860            main_rs.contains("forward_prefill_batch"),
8861            "main.rs should use forward_prefill_batch for prompt tokens"
8862        );
8863    }
8864
8865    #[test]
8866    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8867        let config = minimal_config();
8868        let model_rs = generate_model_rs(&config).unwrap();
8869
8870        let batch_start = model_rs
8871            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8872            .unwrap();
8873        let batch_body = &model_rs[batch_start..];
8874        let batch_end = batch_body
8875            .find("\n    pub fn reset")
8876            .unwrap_or(batch_body.len());
8877        let batch_code = &batch_body[..batch_end];
8878
8879        assert!(
8880            batch_code.contains("dispatch_matmul_batch"),
8881            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8882        );
8883        // Should NOT use Q8 or Q4 batch dispatch
8884        assert!(
8885            !batch_code.contains("dispatch_matmul_q8_batch"),
8886            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8887        );
8888        assert!(
8889            !batch_code.contains("dispatch_matmul_q4_batch"),
8890            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8891        );
8892    }
8893
8894    #[test]
8895    fn forward_uses_cpu_embedding_lookup() {
8896        let config = minimal_config();
8897        let model_rs = generate_model_rs(&config).unwrap();
8898
8899        // Find just the forward() body (not forward_profile)
8900        let forward_start = model_rs
8901            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8902            .unwrap();
8903        let forward_body = &model_rs[forward_start..];
8904        let forward_end = forward_body
8905            .find("\n    pub fn forward_profile")
8906            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8907            .unwrap_or(forward_body.len());
8908        let forward_code = &forward_body[..forward_end];
8909
8910        // forward() should use CPU memcpy for embedding lookup (unified memory)
8911        assert!(
8912            forward_code.contains("embed_buf.contents()"),
8913            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8914        );
8915        assert!(
8916            forward_code.contains("copy_nonoverlapping"),
8917            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8918        );
8919        // forward() should NOT use GPU dispatch for embedding
8920        assert!(
8921            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8922            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8923        );
8924    }
8925
8926    #[test]
8927    fn forward_profile_method_exists() {
8928        let config = minimal_config();
8929        let model_rs = generate_model_rs(&config).unwrap();
8930
8931        assert!(
8932            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8933            "MetalModel should have forward_profile() method"
8934        );
8935        // Profile method should print timing information
8936        assert!(
8937            model_rs.contains("[profile]"),
8938            "forward_profile() should print timing with [profile] prefix"
8939        );
8940        assert!(
8941            model_rs.contains("d_embed"),
8942            "forward_profile() should measure embedding time"
8943        );
8944        assert!(
8945            model_rs.contains("d_layers"),
8946            "forward_profile() should measure layer time"
8947        );
8948        assert!(
8949            model_rs.contains("d_logits"),
8950            "forward_profile() should measure logits time"
8951        );
8952    }
8953
8954    #[test]
8955    fn generated_cli_has_profile_flag() {
8956        let config = minimal_config();
8957        let main_rs = generate_main_rs("test-model", &config).unwrap();
8958
8959        assert!(
8960            main_rs.contains("--profile"),
8961            "CLI should support --profile flag"
8962        );
8963        assert!(
8964            main_rs.contains("forward_profile"),
8965            "CLI should call forward_profile when --profile is set"
8966        );
8967    }
8968
8969    #[test]
8970    fn generated_cli_has_thermal_yield() {
8971        let config = minimal_config();
8972        let main_rs = generate_main_rs("test-model", &config).unwrap();
8973
8974        assert!(
8975            main_rs.contains("yield_now()"),
8976            "CLI generation loop should include thread::yield_now() for thermal management"
8977        );
8978    }
8979
8980    // ── Real-world validation tests ──────────────────────────────────────
8981
8982    #[test]
8983    fn generated_forward_handles_single_token_prompt() {
8984        // With a single token (the first prompt token), forward() should work
8985        // at pos=0 where seq_len=1. The attention kernel must handle the case
8986        // where there is only one KV entry (no prefill context).
8987        let config = minimal_config();
8988        let model_rs = generate_model_rs(&config).unwrap();
8989
8990        // The forward function should accept any u32 token_id (no minimum pos guard)
8991        let forward_start = model_rs
8992            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8993            .expect("forward() must exist");
8994        let forward_body = &model_rs[forward_start..forward_start + 400];
8995
8996        // Should NOT require pos > 0 or seq_len > 1
8997        assert!(
8998            !forward_body.contains("assert!(self.pos > 0"),
8999            "forward() must accept pos=0 (first token with no prefill)"
9000        );
9001
9002        // The attention kernel should handle seq_len=1 via the pos field
9003        assert!(
9004            model_rs.contains("self.pos"),
9005            "forward() should use self.pos to track sequence position"
9006        );
9007    }
9008
9009    #[test]
9010    fn generated_reset_clears_kv_cache_position() {
9011        // After reset(), the model should be in a clean state. The pos field
9012        // must be 0 so new generation starts from scratch.
9013        let config = minimal_config();
9014        let model_rs = generate_model_rs(&config).unwrap();
9015
9016        let reset_start = model_rs
9017            .find("pub fn reset(&mut self)")
9018            .expect("reset() must exist");
9019        let reset_body = &model_rs[reset_start..reset_start + 200];
9020
9021        // Reset must zero the position counter
9022        assert!(
9023            reset_body.contains("self.pos = 0"),
9024            "reset() must set self.pos = 0"
9025        );
9026
9027        // Verify reset clears prev_cmd (double-buffering state)
9028        assert!(
9029            reset_body.contains("self.prev_cmd = None"),
9030            "reset() should clear prev_cmd for clean command buffer state"
9031        );
9032    }
9033
9034    #[test]
9035    fn generated_serve_handles_empty_messages_gracefully() {
9036        // The serve endpoint should not crash when receiving an empty messages array.
9037        // The format_chat_messages function should handle this gracefully.
9038        let config = minimal_config();
9039        let main_rs = generate_main_rs("test-model", &config).unwrap();
9040
9041        // The format_chat_messages function should exist and handle empty input
9042        let format_fn_start = main_rs
9043            .find("fn format_chat_messages")
9044            .expect("format_chat_messages must exist");
9045        let format_fn_body =
9046            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
9047
9048        // It should iterate over messages (an empty slice produces an empty loop)
9049        assert!(
9050            format_fn_body.contains("for msg in messages"),
9051            "format_chat_messages should iterate over the messages slice"
9052        );
9053        // It should always append the assistant prompt suffix
9054        assert!(
9055            format_fn_body.contains("<|im_start|>assistant"),
9056            "format_chat_messages should always append assistant prompt header"
9057        );
9058
9059        // The serve function should call model.reset() before each request
9060        let serve_fn_start = main_rs
9061            .find("fn serve(")
9062            .expect("serve function must exist");
9063        let serve_fn_body = &main_rs[serve_fn_start..];
9064        assert!(
9065            serve_fn_body.contains("model.reset()"),
9066            "serve function should reset model between requests"
9067        );
9068    }
9069
9070    #[test]
9071    fn generated_model_forward_increments_pos() {
9072        // Each forward() call must increment self.pos so the next token
9073        // uses the correct RoPE position and KV cache offset.
9074        let config = minimal_config();
9075        let model_rs = generate_model_rs(&config).unwrap();
9076
9077        let forward_start = model_rs
9078            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
9079            .unwrap();
9080        let forward_body = &model_rs[forward_start..];
9081        let forward_end = forward_body
9082            .find("\n    pub fn forward_profile")
9083            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
9084            .or_else(|| forward_body.find("\n    fn dispatch_"))
9085            .unwrap_or(forward_body.len());
9086        let forward_code = &forward_body[..forward_end];
9087
9088        assert!(
9089            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
9090            "forward() must increment self.pos after processing a token"
9091        );
9092    }
9093}