Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a name Cargo accepts for both the
/// package and its binary target.
///
/// - ASCII letters are lowercased and ASCII digits are kept.
/// - Every other character (including `_`, whitespace, punctuation and all
///   non-ASCII characters) becomes `-`.
/// - Leading and trailing dashes are trimmed.
///
/// Two further cases would otherwise yield a manifest Cargo refuses to parse:
/// - If nothing survives (e.g. `""` or `"!!!"`), the result is `"model"`.
/// - If the result starts with a digit (e.g. `"135m"`), it is prefixed with
///   `"model-"`.
///
/// ASCII-only output also avoids `char::is_alphanumeric` admitting
/// characters such as `'½'` or `'²'`, which are not identifier characters.
fn sanitize_name(name: &str) -> String {
    let dashed: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_lowercase()
            } else {
                '-'
            }
        })
        .collect();
    let trimmed = dashed.trim_matches('-');

    match trimmed.chars().next() {
        // Cargo rejects an empty package name.
        None => String::from("model"),
        // Cargo rejects package names whose first character is a digit.
        Some(first) if first.is_ascii_digit() => format!("model-{trimmed}"),
        Some(_) => trimmed.to_owned(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    // The attention-scores shared array must be at least effective_seq_len
122    // elements; anything smaller silently overflows for long prompts.  For
123    // small-context models (135M at 2K), a smaller array saves threadgroup
124    // memory and improves occupancy — so we size it precisely.
125    let attn_scores_size = config.max_seq_len.min(4096);
126    r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154    device const float* matrix [[buffer(0)]],
155    device const float* vector [[buffer(1)]],
156    device float* output       [[buffer(2)]],
157    constant uint& rows        [[buffer(3)]],
158    constant uint& cols        [[buffer(4)]],
159    uint tgid [[threadgroup_position_in_grid]],
160    uint tid [[thread_index_in_threadgroup]],
161    uint simd_lane [[thread_index_in_simdgroup]],
162    uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164    // Cooperatively load vector into threadgroup shared memory
165    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166    for (uint i = tid; i < cols; i += 256) {
167        vec_tile[i] = vector[i];
168    }
169    threadgroup_barrier(mem_flags::mem_threadgroup);
170
171    // Each simdgroup handles 8 consecutive rows
172    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173    if (row_base >= rows) return;
174
175    uint base0 = row_base * cols;
176    uint base1 = (row_base + 1) * cols;
177    uint base2 = (row_base + 2) * cols;
178    uint base3 = (row_base + 3) * cols;
179    uint base4 = (row_base + 4) * cols;
180    uint base5 = (row_base + 5) * cols;
181    uint base6 = (row_base + 6) * cols;
182    uint base7 = (row_base + 7) * cols;
183
184    // float4 vectorized accumulation across 8 rows
185    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
186    float4 sum4_0 = float4(0.0f);
187    float4 sum4_1 = float4(0.0f);
188    float4 sum4_2 = float4(0.0f);
189    float4 sum4_3 = float4(0.0f);
190    float4 sum4_4 = float4(0.0f);
191    float4 sum4_5 = float4(0.0f);
192    float4 sum4_6 = float4(0.0f);
193    float4 sum4_7 = float4(0.0f);
194
195    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205    }
206
207    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216    // Handle remaining elements (cols not divisible by 128)
217    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218        float vv = vec_tile[j];
219        sum0 += matrix[base0 + j] * vv;
220        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227    }
228
229    // Simdgroup hardware warp-level reduction
230    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235    // Only first lane writes the results
236    if (simd_lane == 0) {
237        if (row_base     < rows) output[row_base]     = sum0;
238        if (row_base + 1 < rows) output[row_base + 1] = sum1;
239        if (row_base + 2 < rows) output[row_base + 2] = sum2;
240        if (row_base + 3 < rows) output[row_base + 3] = sum3;
241        if (row_base + 4 < rows) output[row_base + 4] = sum4;
242        if (row_base + 5 < rows) output[row_base + 5] = sum5;
243        if (row_base + 6 < rows) output[row_base + 6] = sum6;
244        if (row_base + 7 < rows) output[row_base + 7] = sum7;
245    }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253    device const float* input   [[buffer(0)]],
254    device const float* weight  [[buffer(1)]],
255    device float* output        [[buffer(2)]],
256    constant uint& n            [[buffer(3)]],
257    constant float& eps         [[buffer(4)]],
258    uint tid [[thread_index_in_threadgroup]])
259{
260    // Each thread accumulates partial sum-of-squares
261    float sum_sq = 0.0f;
262    for (uint i = tid; i < n; i += 256) {
263        float v = input[i];
264        sum_sq += v * v;
265    }
266
267    // Simdgroup-level reduction (hardware warp sum)
268    sum_sq = simd_sum(sum_sq);
269
270    // Cross-simdgroup reduction via shared memory
271    threadgroup float shared[8];
272    uint simd_id = tid / 32;
273    uint simd_lane = tid % 32;
274    if (simd_lane == 0) {
275        shared[simd_id] = sum_sq;
276    }
277    threadgroup_barrier(mem_flags::mem_threadgroup);
278
279    // First thread computes final inverse RMS
280    if (tid == 0) {
281        float total = 0.0f;
282        for (uint i = 0; i < 8; i++) {
283            total += shared[i];
284        }
285        shared[0] = fast::rsqrt(total / float(n) + eps);
286    }
287    threadgroup_barrier(mem_flags::mem_threadgroup);
288
289    float inv_rms = shared[0];
290
291    // Normalize
292    for (uint i = tid; i < n; i += 256) {
293        output[i] = input[i] * inv_rms * weight[i];
294    }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301    device float* data        [[buffer(0)]],
302    constant uint& num_heads  [[buffer(1)]],
303    constant uint& head_dim   [[buffer(2)]],
304    constant uint& pos        [[buffer(3)]],
305    constant float& theta     [[buffer(4)]],
306    uint id [[thread_position_in_grid]])
307{
308    uint half_dim = head_dim / 2;
309    uint total_pairs = num_heads * half_dim;
310    if (id >= total_pairs) return;
311
312    uint h = id / half_dim;
313    uint i = id % half_dim;
314    uint off = h * head_dim;
315
316    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317    float angle = float(pos) * freq;
318    float c = cos(angle);
319    float s = sin(angle);
320
321    float x0 = data[off + 2 * i];
322    float x1 = data[off + 2 * i + 1];
323    data[off + 2 * i]     = x0 * c - x1 * s;
324    data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331    device float* data       [[buffer(0)]],
332    constant uint& n         [[buffer(1)]],
333    uint tid [[thread_index_in_threadgroup]],
334    uint tg_size [[threads_per_threadgroup]])
335{
336    threadgroup float shared_val[256];
337
338    // Pass 1: find max
339    float local_max = -INFINITY;
340    for (uint i = tid; i < n; i += tg_size) {
341        local_max = max(local_max, data[i]);
342    }
343    shared_val[tid] = local_max;
344    threadgroup_barrier(mem_flags::mem_threadgroup);
345
346    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347        if (tid < stride) {
348            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349        }
350        threadgroup_barrier(mem_flags::mem_threadgroup);
351    }
352    float max_val = shared_val[0];
353    threadgroup_barrier(mem_flags::mem_threadgroup);
354
355    // Pass 2: exp and sum
356    float local_sum = 0.0f;
357    for (uint i = tid; i < n; i += tg_size) {
358        float e = fast::exp(data[i] - max_val);
359        data[i] = e;
360        local_sum += e;
361    }
362    shared_val[tid] = local_sum;
363    threadgroup_barrier(mem_flags::mem_threadgroup);
364
365    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366        if (tid < stride) {
367            shared_val[tid] += shared_val[tid + stride];
368        }
369        threadgroup_barrier(mem_flags::mem_threadgroup);
370    }
371    float inv_sum = 1.0f / shared_val[0];
372    threadgroup_barrier(mem_flags::mem_threadgroup);
373
374    // Pass 3: normalize
375    for (uint i = tid; i < n; i += tg_size) {
376        data[i] *= inv_sum;
377    }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384    device const float* gate [[buffer(0)]],
385    device const float* up   [[buffer(1)]],
386    device float* output     [[buffer(2)]],
387    constant uint& n         [[buffer(3)]],
388    uint id [[thread_position_in_grid]])
389{
390    if (id >= n) return;
391    float g = gate[id];
392    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397//   gate = gate_up[0..n], up = gate_up[n..2*n]
398//   output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400    device const float* gate_up [[buffer(0)]],
401    device float* output        [[buffer(1)]],
402    constant uint& n            [[buffer(2)]],
403    uint id [[thread_position_in_grid]])
404{
405    if (id >= n) return;
406    float g = gate_up[id];
407    float u = gate_up[n + id];
408    output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414    device const float* a  [[buffer(0)]],
415    device const float* b  [[buffer(1)]],
416    device float* output   [[buffer(2)]],
417    constant uint& n       [[buffer(3)]],
418    uint id [[thread_position_in_grid]])
419{
420    if (id >= n) return;
421    output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428    device const float* src [[buffer(0)]],
429    device float* dst       [[buffer(1)]],
430    constant uint& count    [[buffer(2)]],
431    uint id [[thread_position_in_grid]])
432{
433    if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440    device const float* src     [[buffer(0)]],
441    device float* dst           [[buffer(1)]],
442    constant uint& src_offset   [[buffer(2)]],  // in floats
443    constant uint& count        [[buffer(3)]],
444    uint id [[thread_position_in_grid]])
445{
446    if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── add_inplace ─────────────────────────────────────────────────────────
450// In-place residual connection: a[i] += b[i]
451// Avoids a separate blit copy for residual add, reducing encoder overhead.
452kernel void add_inplace(
453    device float* a        [[buffer(0)]],
454    device const float* b  [[buffer(1)]],
455    constant uint& n       [[buffer(2)]],
456    uint id [[thread_position_in_grid]])
457{
458    if (id >= n) return;
459    a[id] += b[id];
460}
461
462// ── matmul_vec_q8 ─────────────────────────────────────────────────────
463// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
464// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
465// Operates directly on quantized weights to halve memory bandwidth vs f32,
466// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
467//
468// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
469// because int8->float conversion doubles register demand.  Fully unrolled
470// inner loop with float4 vector loads from shared memory eliminates loop
471// overhead and enables better instruction scheduling.
472// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
473constant constexpr uint Q8_ROWS_PER_SG = 4;
474constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
475
476// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
477// doubles ALU work, so the same register budget applies).
478constant constexpr uint Q4_ROWS_PER_SG = 4;
479constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
480
481kernel void matmul_vec_q8(
482    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
483    device const float* vector   [[buffer(1)]],  // f32 input
484    device float* output         [[buffer(2)]],
485    constant uint& rows          [[buffer(3)]],
486    constant uint& cols          [[buffer(4)]],  // number of elements per row
487    uint tgid [[threadgroup_position_in_grid]],
488    uint tid [[thread_index_in_threadgroup]],
489    uint simd_lane [[thread_index_in_simdgroup]],
490    uint simd_id [[simdgroup_index_in_threadgroup]])
491{
492    // Load vector into shared memory
493    threadgroup float vec_tile[VEC_TILE_SIZE];
494    for (uint i = tid; i < cols; i += 256) {
495        vec_tile[i] = vector[i];
496    }
497    threadgroup_barrier(mem_flags::mem_threadgroup);
498
499    // Each simdgroup handles 4 consecutive rows (lower register pressure)
500    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
501    if (row_base >= rows) return;
502
503    // Q8_0: each block is 34 bytes for 32 elements
504    uint blocks_per_row = cols / 32;
505    uint row_bytes = blocks_per_row * 34;
506
507    // Pointers to each row's Q8_0 data
508    device const uchar* r0 = matrix + row_base * row_bytes;
509    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
510    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
511    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
512
513    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
514
515    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
516        uint bb = blk * 34;
517        uint vb = blk * 32;
518
519        // Prefetch all 4 scales
520        float sc0 = float(*(device const half*)(r0 + bb));
521        float sc1 = float(*(device const half*)(r1 + bb));
522        float sc2 = float(*(device const half*)(r2 + bb));
523        float sc3 = float(*(device const half*)(r3 + bb));
524
525        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
526        // Q8_0 block layout where the int8 data starts at offset +2 from a
527        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
528        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
529        // reduction in memory transactions. Metal's char16/packed_char16 are
530        // reserved types and packed_*int4 require >=4-byte alignment which
531        // this layout does not provide, so packed_short4 is the widest valid
532        // vectorized load.
533        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
534        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
535        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
536        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
537
538        // Load all 8 float4 vector values for this 32-element block from shared memory
539        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
540        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
541        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
542        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
543        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
544        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
545        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
546        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
547
548        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
549        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
550        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
551            short4 _s = short4(SHORT4); \
552            char2 _a = as_type<char2>(_s.x); \
553            char2 _b = as_type<char2>(_s.y); \
554            char2 _c = as_type<char2>(_s.z); \
555            char2 _d = as_type<char2>(_s.w); \
556            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
557            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
558        }
559
560        float4 f0, f1;
561        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
562
563        // Row 0: 4 short4 loads cover 32 int8 weights
564        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
565        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
566        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
567        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
568
569        // Row 1
570        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
571        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
572        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
573        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
574
575        // Row 2
576        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
577        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
578        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
579        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
580
581        // Row 3
582        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
583        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
584        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
585        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
586
587        #undef Q8_UNPACK8
588
589        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
590    }
591
592    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
593    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
594
595    if (simd_lane == 0) {
596        if (row_base     < rows) output[row_base]     = sum0;
597        if (row_base + 1 < rows) output[row_base + 1] = sum1;
598        if (row_base + 2 < rows) output[row_base + 2] = sum2;
599        if (row_base + 3 < rows) output[row_base + 3] = sum3;
600    }
601}
602
603// ── matmul_vec_q4 ─────────────────────────────────────────────────────
604// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
605// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
606// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
607// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
608//
609// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
610// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
611kernel void matmul_vec_q4(
612    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
613    device const float* vector   [[buffer(1)]],  // f32 input
614    device float* output         [[buffer(2)]],
615    constant uint& rows          [[buffer(3)]],
616    constant uint& cols          [[buffer(4)]],  // number of elements per row
617    uint tgid [[threadgroup_position_in_grid]],
618    uint tid [[thread_index_in_threadgroup]],
619    uint simd_lane [[thread_index_in_simdgroup]],
620    uint simd_id [[simdgroup_index_in_threadgroup]])
621{
622    // Load vector into shared memory
623    threadgroup float vec_tile[VEC_TILE_SIZE];
624    for (uint i = tid; i < cols; i += 256) {
625        vec_tile[i] = vector[i];
626    }
627    threadgroup_barrier(mem_flags::mem_threadgroup);
628
629    // Each simdgroup handles 4 consecutive rows
630    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
631    if (row_base >= rows) return;
632
633    // Q4_0: each block is 18 bytes for 32 elements
634    uint blocks_per_row = cols / 32;
635    uint row_bytes = blocks_per_row * 18;
636
637    // Pointers to each row's Q4_0 data
638    device const uchar* r0 = matrix + row_base * row_bytes;
639    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
640    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
641    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
642
643    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
644
645    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
646        uint bb = blk * 18;
647        uint vb = blk * 32;
648
649        // Prefetch all 4 scales
650        float sc0 = float(*(device const half*)(r0 + bb));
651        float sc1 = float(*(device const half*)(r1 + bb));
652        float sc2 = float(*(device const half*)(r2 + bb));
653        float sc3 = float(*(device const half*)(r3 + bb));
654
655        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
656        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
657        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
658        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
659        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
660
661        // Load 8 float4 vector values for 32 elements from shared memory
662        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
663        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
664        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
665        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
666        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
667        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
668        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
669        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
670        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
671
672        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
673        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
674        float bd0=0, bd1=0, bd2=0, bd3=0;
675        uchar4 b;
676
677        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
678        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
679                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
680                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
681                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
682        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
683                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
684                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
685                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
686        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
687                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
688                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
689                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
690        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
691                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
692                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
693                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
694
695        // Row 1
696        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
697                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
698                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
699                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
700        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
701                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
702                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
703                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
704        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
705                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
706                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
707                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
708        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
709                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
710                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
711                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
712
713        // Row 2
714        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
715                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
716                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
717                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
718        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
719                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
720                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
721                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
722        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
723                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
724                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
725                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
726        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
727                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
728                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
729                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
730
731        // Row 3
732        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
733                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
734                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
735                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
736        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
737                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
738                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
739                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
740        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
741                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
742                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
743                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
744        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
745                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
746                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
747                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
748
749        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
750    }
751
752    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
753    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
754
755    if (simd_lane == 0) {
756        if (row_base     < rows) output[row_base]     = sum0;
757        if (row_base + 1 < rows) output[row_base + 1] = sum1;
758        if (row_base + 2 < rows) output[row_base + 2] = sum2;
759        if (row_base + 3 < rows) output[row_base + 3] = sum3;
760    }
761}
762
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then weighted sum of V.
// Each threadgroup handles one head with 256 threads (8 simdgroups); the
// loop strides and the 8-entry reduction arrays below assume exactly that
// dispatch shape.
//
// Grouped-query attention: query head h reads KV head
// h / (num_heads / num_kv_heads), so num_heads should be a multiple of
// num_kv_heads.
//
// Scores are staged in a threadgroup array of ATTN_SCORES_SIZE floats. The
// host is expected to keep seq_len <= ATTN_SCORES_SIZE. If it does not, the
// kernel attends only to the most recent ATTN_SCORES_SIZE positions rather
// than writing past the end of the threadgroup array.
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;
    uint kv_stride = num_kv_heads * head_dim;  // floats between consecutive cache positions
    uint kv_base = kv_head * head_dim;         // offset of this KV head within one position

    // Number of positions that fit in the score buffer, and the first one
    // attended. In the normal case (seq_len <= ATTN_SCORES_SIZE) this is the
    // full range [0, seq_len) and the kernel is plain full attention.
    uint window = min(seq_len, uint(ATTN_SCORES_SIZE));
    uint first_pos = seq_len - window;

    // Step 1: Compute scaled attention scores Q·K^T with simdgroup reduction.
    // Simdgroup g handles positions g, g+8, g+16, ...; its lanes split head_dim.
    threadgroup float scores[ATTN_SCORES_SIZE];
    float scale = fast::rsqrt(float(head_dim));

    for (uint i = simd_id; i < window; i += 8) {
        uint k_off = (first_pos + i) * kv_stride + kv_base;
        float qk = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            qk += q[q_off + d] * k_cache[k_off + d];
        }
        qk = simd_sum(qk);
        if (simd_lane == 0) {
            scores[i] = qk * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Numerically stable softmax over scores[0, window).
    // 2a: max, reduced per thread, then per simdgroup, then across simdgroups.
    float local_max = -INFINITY;
    for (uint i = tid; i < window; i += 256) {
        local_max = max(local_max, scores[i]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate in place and sum, reduced the same way.
    float local_sum = 0.0;
    for (uint i = tid; i < window; i += 256) {
        scores[i] = fast::exp(scores[i] - max_val);
        local_sum += scores[i];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    // 2c: normalize so the weights sum to 1.
    for (uint i = tid; i < window; i += 256) {
        scores[i] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V over the attended window: output = scores · V.
    // Thread tid owns dimensions d = tid, tid+256, ...; positions are
    // processed 4 at a time for ILP, followed by a scalar tail.
    uint window4 = window & ~3u;  // largest multiple of 4 <= window
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = first_pos * kv_stride + kv_base + d;
        for (uint i = 0; i < window4; i += 4) {
            acc += scores[i]     * v_cache[i * kv_stride + v_base]
                 + scores[i + 1] * v_cache[(i + 1) * kv_stride + v_base]
                 + scores[i + 2] * v_cache[(i + 2) * kv_stride + v_base]
                 + scores[i + 3] * v_cache[(i + 3) * kv_stride + v_base];
        }
        for (uint i = window4; i < window; i++) {
            acc += scores[i] * v_cache[i * kv_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
879
880// ── Batched prefill kernels ────────────────────────────────────────────
881// These kernels process M input vectors against the same weight matrix
882// in a single dispatch, converting mat-vec into mat-mat for better GPU
883// utilization during prompt prefill.
884
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization, one threadgroup per token row:
//   output[t, i] = input[t, i] * rsqrt(sum_j input[t, j]^2 / n + eps) * weight[i]
// Grid: M threadgroups. The 256-element stride and the 8-slot partial-sum
// array both assume 256 threads (8 simdgroups of 32) per threadgroup.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    if (tgid >= num_tokens) return;

    device const float* row_in  = input  + tgid * n;
    device float*       row_out = output + tgid * n;

    // Each thread accumulates the squares of a strided subset of the row.
    float acc = 0.0f;
    for (uint i = tid; i < n; i += 256) {
        float x = row_in[i];
        acc += x * x;
    }

    // Two-level reduction: across the 32 lanes, then across the 8 simdgroups.
    acc = simd_sum(acc);
    threadgroup float sg_partial[8];
    uint sg = tid / 32;
    uint lane = tid % 32;
    if (lane == 0) sg_partial[sg] = acc;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 turns the total into the inverse RMS and publishes it in slot 0.
    if (tid == 0) {
        float total = 0.0f;
        for (uint k = 0; k < 8; k++) total += sg_partial[k];
        sg_partial[0] = fast::rsqrt(total / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_rms = sg_partial[0];

    for (uint i = tid; i < n; i += 256) {
        row_out[i] = row_in[i] * inv_rms * weight[i];
    }
}
934
// ── rope_batch ─────────────────────────────────────────────────────────
// Batched rotary position embedding, applied in place.
// data is [M, num_heads * head_dim]. Within every head, each adjacent pair
// (2i, 2i+1) of token t is rotated by positions[t] * theta^(-2i/head_dim).
// One thread per (token, head, pair); surplus threads exit immediately.
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    uint pairs_per_head  = head_dim / 2;
    uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread index into (token, head, pair).
    uint token    = id / pairs_per_token;
    uint head     = (id % pairs_per_token) / pairs_per_head;
    uint pair_idx = (id % pairs_per_token) % pairs_per_head;

    device float* x = data + token * (num_heads * head_dim) + head * head_dim + 2 * pair_idx;

    float inv_freq = 1.0f / pow(theta, 2.0f * float(pair_idx) / float(head_dim));
    float angle = float(positions[token]) * inv_freq;
    float cos_a = cos(angle);
    float sin_a = sin(angle);

    float even = x[0];
    float odd  = x[1];
    x[0] = even * cos_a - odd * sin_a;
    x[1] = even * sin_a + odd * cos_a;
}
969
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SwiGLU-style activation for a batch of tokens.
// Each row of gate_up holds n gate values followed by n up values; the
// result row is silu(gate) * up, with silu(g) = g / (1 + exp(-g)).
// One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    uint token = id / n;
    uint col   = id % n;
    device const float* row = gate_up + token * 2 * n;

    float gate = row[col];
    float up   = row[n + col];
    float silu = gate / (1.0f + fast::exp(-gate));

    // id == token * n + col, i.e. the matching element of the [M, n] output.
    output[id] = silu * up;
}
989
// ── add_inplace_batch ──────────────────────────────────────────────────
// Residual add over a flattened batch: a[i] = a[i] + b[i] for i < total.
// One thread per element; threads beyond total do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1001
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather embedding rows for a batch of token IDs into a contiguous
// [M, dim] buffer. One thread per output float; token IDs are not
// range-checked here.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    uint row = id / dim;
    uint col = id % dim;
    device const float* src = embed + tokens[row] * dim;
    output[id] = src[col];
}
1019
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair.
// Within a threadgroup, each simdgroup computes 8 consecutive output rows
// for that token (the 8 hard-coded accumulators below; presumably this
// matches ROWS_PER_SIMDGROUP == 8 — TODO confirm against its definition).
//
// Preconditions this kernel does not check:
//   - cols <= VEC_TILE_SIZE (the input vector is staged in threadgroup memory);
//   - 256 threads per threadgroup (the staging loop strides by 256);
//   - matrix rows start on 16-byte boundaries for the float4 loads, i.e.
//     cols is a multiple of 4 — NOTE(review): confirm for every weight shape.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Split the flat threadgroup index into (token, row group within token).
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Exiting after the barrier is safe: no further threadgroup barriers follow.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Start offsets of this simdgroup's 8 rows in the row-major matrix.
    // Offsets for rows past the end are computed but only used behind guards.
    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Despite the name, cols_vec4 is cols rounded down to a multiple of 128:
    // each of the 32 lanes consumes one float4 per iteration (32 * 4 = 128).
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    // Vectorized main loop. Row 0 is known valid (row_base < rows); rows 1-7
    // are guarded so a partial final row group never reads past the matrix.
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Collapse each float4 accumulator into this lane's scalar partial sum.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail: the remaining cols % 128 columns, one per lane per iteration.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Reduce each row's partial sums across the 32 lanes of the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    // Lane 0 stores this token's results; rows past the end are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1121
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups.
//
// Q8_0 layout (as read below): each row is cols/32 blocks of 34 bytes, a
// half-precision scale followed by 32 signed 8-bit weights. Each simdgroup
// produces 4 consecutive output rows for one token (presumably
// Q8_ROWS_PER_SG == 4); its lanes stride over the blocks of those rows.
//
// Preconditions this kernel does not check: cols is a multiple of 32,
// cols <= VEC_TILE_SIZE, and the threadgroup has 256 threads.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Split the flat threadgroup index into (token, row group within token).
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Exiting after the barrier is safe: no further threadgroup barriers follow.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4, rows row_base+1..+3 may not exist.
    // Clamp their pointers to the last real row so every load below stays
    // inside the weight buffer. The values computed for such phantom rows are
    // never stored, because the output writes at the end are bounds-checked.
    // rows >= 1 here (row_base < rows), so rows - 1 cannot underflow.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of this block within a row
        uint vb = blk * 32;  // first input element covered by this block

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Reinterpret 4 shorts as 8 signed bytes and widen them to two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply each row's per-block scale once to the block's integer dot product.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Reduce each row's partial sums across the 32 lanes of the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 stores this token's results; phantom rows past the end are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1237
1238// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
1239// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
1240// Each threadgroup covers 32 rows and TOKENS_PER_TG consecutive tokens, so
1241// the Q8_0 weight blocks are fetched once from device memory and reused for
1242// every token in the tile (1/TOKENS_PER_TG the weight bandwidth of the
1243// per-token dispatch).
1244//
1245// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG)) threadgroups.
1246// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
1247// with simd_sum.  Token vectors are read directly from device memory inside
1248// the block loop (not cached in shared memory) so intermediate_size up to
1249// 8192 fits without spilling threadgroup memory.
1250constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1251
1252kernel void matmul_q8_gemm_batch(
1253    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1254    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1255    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1256    constant uint& num_tokens    [[buffer(3)]],  // M
1257    constant uint& rows          [[buffer(4)]],
1258    constant uint& cols          [[buffer(5)]],
1259    uint2 tgid [[threadgroup_position_in_grid]],
1260    uint tid [[thread_index_in_threadgroup]],
1261    uint simd_lane [[thread_index_in_simdgroup]],
1262    uint simd_id [[simdgroup_index_in_threadgroup]])
1263{
1264    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1265    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1266    if (row_base >= rows || tok_base >= num_tokens) return;
1267
1268    // How many tokens in this tile are valid?
1269    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1270
1271    uint blocks_per_row = cols / 32;
1272    uint row_bytes = blocks_per_row * 34;
1273
1274    device const uchar* r0 = matrix + row_base * row_bytes;
1275    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1276    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1277    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1278
1279    // Accumulators: 4 tokens × 4 rows per simdgroup.
1280    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1281    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1282    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1283    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1284
1285    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1286        uint bb = blk * 34;
1287        uint vb = blk * 32;
1288
1289        // ── Load weight data ONCE per block (reused across all tokens) ──
1290        float sc0 = float(*(device const half*)(r0 + bb));
1291        float sc1 = float(*(device const half*)(r1 + bb));
1292        float sc2 = float(*(device const half*)(r2 + bb));
1293        float sc3 = float(*(device const half*)(r3 + bb));
1294
1295        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1296        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1297        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1298        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1299
1300        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1301            short4 _s = short4(SHORT4); \
1302            char2 _a = as_type<char2>(_s.x); \
1303            char2 _b = as_type<char2>(_s.y); \
1304            char2 _c = as_type<char2>(_s.z); \
1305            char2 _d = as_type<char2>(_s.w); \
1306            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1307            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1308        }
1309
1310        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1311        // registers for the duration of the block and are dotted against
1312        // every token's vector tile.
1313        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1314        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1315        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1316        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1317
1318        Q8_UNPACK8(d0[0], w0_0, w0_1);
1319        Q8_UNPACK8(d0[1], w0_2, w0_3);
1320        Q8_UNPACK8(d0[2], w0_4, w0_5);
1321        Q8_UNPACK8(d0[3], w0_6, w0_7);
1322
1323        Q8_UNPACK8(d1[0], w1_0, w1_1);
1324        Q8_UNPACK8(d1[1], w1_2, w1_3);
1325        Q8_UNPACK8(d1[2], w1_4, w1_5);
1326        Q8_UNPACK8(d1[3], w1_6, w1_7);
1327
1328        Q8_UNPACK8(d2[0], w2_0, w2_1);
1329        Q8_UNPACK8(d2[1], w2_2, w2_3);
1330        Q8_UNPACK8(d2[2], w2_4, w2_5);
1331        Q8_UNPACK8(d2[3], w2_6, w2_7);
1332
1333        Q8_UNPACK8(d3[0], w3_0, w3_1);
1334        Q8_UNPACK8(d3[1], w3_2, w3_3);
1335        Q8_UNPACK8(d3[2], w3_4, w3_5);
1336        Q8_UNPACK8(d3[3], w3_6, w3_7);
1337
1338        #undef Q8_UNPACK8
1339
1340        // ── For each token, read vector and accumulate against shared weights ──
1341        // Token 0 (always valid: tok_count >= 1).
1342        {
1343            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1344            float4 v0 = *(device const float4*)(a0);
1345            float4 v1 = *(device const float4*)(a0 + 4);
1346            float4 v2 = *(device const float4*)(a0 + 8);
1347            float4 v3 = *(device const float4*)(a0 + 12);
1348            float4 v4 = *(device const float4*)(a0 + 16);
1349            float4 v5 = *(device const float4*)(a0 + 20);
1350            float4 v6 = *(device const float4*)(a0 + 24);
1351            float4 v7 = *(device const float4*)(a0 + 28);
1352            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1353                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1354            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1355                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1356            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1357                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1358            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1359                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1360            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1361        }
1362        // Token 1
1363        if (tok_count > 1) {
1364            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1365            float4 v0 = *(device const float4*)(a1);
1366            float4 v1 = *(device const float4*)(a1 + 4);
1367            float4 v2 = *(device const float4*)(a1 + 8);
1368            float4 v3 = *(device const float4*)(a1 + 12);
1369            float4 v4 = *(device const float4*)(a1 + 16);
1370            float4 v5 = *(device const float4*)(a1 + 20);
1371            float4 v6 = *(device const float4*)(a1 + 24);
1372            float4 v7 = *(device const float4*)(a1 + 28);
1373            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1374                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1375            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1376                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1377            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1378                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1379            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1380                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1381            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1382        }
1383        // Token 2
1384        if (tok_count > 2) {
1385            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1386            float4 v0 = *(device const float4*)(a2);
1387            float4 v1 = *(device const float4*)(a2 + 4);
1388            float4 v2 = *(device const float4*)(a2 + 8);
1389            float4 v3 = *(device const float4*)(a2 + 12);
1390            float4 v4 = *(device const float4*)(a2 + 16);
1391            float4 v5 = *(device const float4*)(a2 + 20);
1392            float4 v6 = *(device const float4*)(a2 + 24);
1393            float4 v7 = *(device const float4*)(a2 + 28);
1394            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1395                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1396            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1397                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1398            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1399                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1400            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1401                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1402            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1403        }
1404        // Token 3
1405        if (tok_count > 3) {
1406            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1407            float4 v0 = *(device const float4*)(a3);
1408            float4 v1 = *(device const float4*)(a3 + 4);
1409            float4 v2 = *(device const float4*)(a3 + 8);
1410            float4 v3 = *(device const float4*)(a3 + 12);
1411            float4 v4 = *(device const float4*)(a3 + 16);
1412            float4 v5 = *(device const float4*)(a3 + 20);
1413            float4 v6 = *(device const float4*)(a3 + 24);
1414            float4 v7 = *(device const float4*)(a3 + 28);
1415            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1416                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1417            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1418                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1419            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1420                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1421            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1422                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1423            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1424        }
1425    }
1426
1427    // simdgroup reduction
1428    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1429    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1430    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1431    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1432
1433    if (simd_lane == 0) {
1434        device float* o0 = outputs + (tok_base + 0) * rows;
1435        if (row_base     < rows) o0[row_base]     = s00;
1436        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1437        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1438        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1439
1440        if (tok_count > 1) {
1441            device float* o1 = outputs + (tok_base + 1) * rows;
1442            if (row_base     < rows) o1[row_base]     = s10;
1443            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1444            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1445            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1446        }
1447        if (tok_count > 2) {
1448            device float* o2 = outputs + (tok_base + 2) * rows;
1449            if (row_base     < rows) o2[row_base]     = s20;
1450            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1451            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1452            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1453        }
1454        if (tok_count > 3) {
1455            device float* o3 = outputs + (tok_base + 3) * rows;
1456            if (row_base     < rows) o3[row_base]     = s30;
1457            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1458            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1459            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1460        }
1461    }
1462}
1463
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
//
//   outputs[t, r] = sum_k inputs[t, k] * float(q[r, k]) * float(scale[r, k / 32])
//
// Q8_0 block (34 bytes): one half scale followed by 32 int8 quants.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups (128 threads) per TG, each computing a single 8×8 output
// sub-tile via one simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes
// are cooperatively dequantized into threadgroup memory once per block and
// reused by all simdgroups in the tile.
//
// Threadgroup memory: two 16×32 float tiles (4 KB).  The partial-token-tile
// store path reuses w_tile as its scratch, so it needs no extra allocation.
//
// Assumptions (verified in the dispatch helper, falls back otherwise):
//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
//   * num_tokens may be any value; a partial token tile at the end of M is
//     written through the scratch copy path.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and no
    // thread is left waiting at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8;  // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Cooperative-load assignment: 512 tile elements / 128 threads = 4 each.
    // Thread tid owns K values [ld_k, ld_k + 4) of tile row ld_row.  Because 4
    // divides 32 that span never crosses a row, so the Q8 scale and the token
    // bounds check are evaluated once per thread rather than per element.
    uint ld_row = tid / 8;          // 0..15
    uint ld_k   = (tid % 8) * 4;    // 0, 4, ..., 28
    uint ld_tok = tok_base + ld_row;
    bool ld_tok_ok = ld_tok < num_tokens;
    device const uchar* w_row = matrix + (row_base + ld_row) * row_bytes;
    threadgroup float* w_dst = w_tile + ld_row * 32 + ld_k;
    threadgroup float* t_dst = t_tile + ld_row * 32 + ld_k;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Dequantize 16 weight rows × 32 K into w_tile ──
        {
            device const uchar* rp = w_row + blk * 34;
            float sc = float(*(device const half*)rp);
            device const int8_t* q = (device const int8_t*)(rp + 2) + ld_k;
            for (uint ii = 0; ii < 4; ii++) {
                w_dst[ii] = float(int(q[ii])) * sc;
            }
        }

        // ── Load 16 token vectors × 32 K into t_tile (zeros past num_tokens) ──
        for (uint ii = 0; ii < 4; ii++) {
            t_dst[ii] = ld_tok_ok
                ? inputs[ld_tok * cols + blk * 32 + ld_k + ii]
                : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B,
                w_tile + sg_row_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Every simdgroup must finish reading the tiles before they are
        // overwritten by the next block (or reused as scratch below).
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (rows is tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    if (out_tok + 8 <= num_tokens) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Last token tile with only 1..7 valid token rows.  Stage the
        // accumulator in this simdgroup's own 64-float slice of w_tile (no
        // longer read: every thread has passed the loop's trailing barrier),
        // then let all 32 lanes copy the valid elements in parallel.
        threadgroup float* scratch = w_tile + simd_id * 64;
        simdgroup_store(C, scratch, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        for (uint e = lane; e < 64; e += 32) {
            uint m = e / 8;
            if (m < valid) {
                outputs[(out_tok + m) * rows + out_row + e % 8] = scratch[e];
            }
        }
    }
}
1592
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
//
//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//     output sub-tiles (tok, row):
//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration (2 MMAs per 3 fragment loads, versus 1 per 2 in the 16×16
// single-accumulator kernel).  Each tile covers 4× the output area of a
// 16×16 tile, so a quarter as many threadgroups are dispatched.
//
// Threadgroup memory: two 32×32 float tiles (8 KB).  The partial-token-tile
// store path reuses w_tile (exactly 8 simdgroups × 128 floats) as scratch.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Cooperative-load assignment: 1024 tile elements / 256 threads = 4 each.
    // Thread tid owns K values [ld_k, ld_k + 4) of tile row ld_row, so the Q8
    // scale and the token bounds check are evaluated once per thread.
    uint ld_row = tid / 8;          // 0..31
    uint ld_k   = (tid % 8) * 4;    // 0, 4, ..., 28
    uint ld_tok = tok_base + ld_row;
    bool ld_tok_ok = ld_tok < num_tokens;
    device const uchar* w_row = matrix + (row_base + ld_row) * row_bytes;
    threadgroup float* w_dst = w_tile + ld_row * 32 + ld_k;
    threadgroup float* t_dst = t_tile + ld_row * 32 + ld_k;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization.
        {
            device const uchar* rp = w_row + blk * 34;
            float sc = float(*(device const half*)rp);
            device const int8_t* q = (device const int8_t*)(rp + 2) + ld_k;
            for (uint ii = 0; ii < 4; ii++) {
                w_dst[ii] = float(int(q[ii])) * sc;
            }
        }

        // Cooperative token tile load (zeros past num_tokens).
        for (uint ii = 0; ii < 4; ii++) {
            t_dst[ii] = ld_tok_ok
                ? inputs[ld_tok * cols + blk * 32 + ld_k + ii]
                : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // All tile reads must finish before the next block overwrites the
        // tiles (or the store path below reuses w_tile as scratch).
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Stage [C_a | C_b] in this simdgroup's 128-float slice of w_tile,
        // then copy the valid token rows using all 32 lanes.
        threadgroup float* scratch = w_tile + simd_id * 128;
        simdgroup_store(C_a, scratch, 8);
        simdgroup_store(C_b, scratch + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        for (uint e = lane; e < 128; e += 32) {
            uint m = (e % 64) / 8;
            if (m < valid) {
                uint out_row = (e < 64) ? out_row_a : out_row_b;
                outputs[(out_tok + m) * rows + out_row + e % 8] = scratch[e];
            }
        }
    }
}
1733
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory, so the two 32×32 tiles take 4 KB instead of the 8 KB they need in
// f32.  The partial-token-tile store path reuses that same 4 KB as float
// scratch, so 4 KB is the kernel's entire threadgroup allocation.  Intended
// win vs mma32: more concurrent threadgroups per GPU core at moderate
// prefill lengths (subject to the device's per-core thread limit as well).
//
// Precision: the Q8_0 scale is itself stored as f16, and each dequantized
// weight (int8 × scale) is rounded to f16, whose 11-bit significand cannot
// always hold the exact product.  Token activations are narrowed f32 → f16,
// which overflows to ±inf for |x| > 65504.  The
// `simdgroup_multiply_accumulate` accumulators stay in `float`.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
// accumulators each.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // One 4 KB backing store.  During the K loop it is viewed as two 32×32
    // half tiles (w_tile = halves [0, 1024), t_tile = halves [1024, 2048));
    // after the loop it is viewed as 1024 floats of partial-store scratch.
    // Declared as float so the scratch view is float-aligned.
    threadgroup float tg_mem[(MMA32_ROW_TILE + MMA32_TOK_TILE) * 32 / 2];
    threadgroup half* w_tile = (threadgroup half*)tg_mem;
    threadgroup half* t_tile = w_tile + MMA32_ROW_TILE * 32;

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Cooperative-load assignment: 1024 tile elements / 256 threads = 4 each.
    // Thread tid owns K values [ld_k, ld_k + 4) of tile row ld_row, so the Q8
    // scale and the token bounds check are evaluated once per thread.
    uint ld_row = tid / 8;          // 0..31
    uint ld_k   = (tid % 8) * 4;    // 0, 4, ..., 28
    uint ld_tok = tok_base + ld_row;
    bool ld_tok_ok = ld_tok < num_tokens;
    device const uchar* w_row = matrix + (row_base + ld_row) * row_bytes;
    threadgroup half* w_dst = w_tile + ld_row * 32 + ld_k;
    threadgroup half* t_dst = t_tile + ld_row * 32 + ld_k;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16.
        {
            device const uchar* rp = w_row + blk * 34;
            float sc = float(*(device const half*)rp);
            device const int8_t* q = (device const int8_t*)(rp + 2) + ld_k;
            for (uint ii = 0; ii < 4; ii++) {
                w_dst[ii] = half(float(int(q[ii])) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing, zeros past num_tokens).
        for (uint ii = 0; ii < 4; ii++) {
            t_dst[ii] = ld_tok_ok
                ? half(inputs[ld_tok * cols + blk * 32 + ld_k + ii])
                : half(0);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // All tile reads must finish before the next block overwrites the
        // tiles (or the store path below reuses tg_mem as float scratch).
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Stage [C_a | C_b] in this simdgroup's 128-float slice of tg_mem,
        // then copy the valid token rows using all 32 lanes.
        threadgroup float* scratch = tg_mem + simd_id * 128;
        simdgroup_store(C_a, scratch, 8);
        simdgroup_store(C_b, scratch + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid = num_tokens - out_tok;  // 1..7
        for (uint e = lane; e < 128; e += 32) {
            uint m = (e % 64) / 8;
            if (m < valid) {
                uint out_row = (e < 64) ? out_row_a : out_row_b;
                outputs[(out_tok + m) * rows + out_row + e % 8] = scratch[e];
            }
        }
    }
}
1862
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's.  That is 4 MMAs per 4 fragment loads (1.5× the 2-per-3 ratio
// of the 2-accumulator kernel).  It also halves the thread count per
// threadgroup (128 threads), which is intended to improve occupancy on Apple
// GPUs where the concurrent-thread budget, rather than threadgroup memory,
// is the tighter limit.
//
// Precision: tiles hold f16 (dequantized int8 × f16 scale rounded to half;
// activations narrowed f32 → f16, overflowing to ±inf for |x| > 65504);
// accumulators are float.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 4 KB total.  During the K loop, tg_mem is viewed as two 32×32 half tiles
    // (w_tile = halves [0, 1024), t_tile = halves [1024, 2048)).  After the
    // loop it is viewed as 1024 floats of partial-store scratch
    // (4 simdgroups × 4 accumulators × 64).
    threadgroup float tg_mem[(MMA32_ROW_TILE + MMA32_TOK_TILE) * 32 / 2];
    threadgroup half* w_tile = (threadgroup half*)tg_mem;
    threadgroup half* t_tile = w_tile + MMA32_ROW_TILE * 32;

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Cooperative-load assignment: 1024 tile elements / 128 threads = 8 each.
    // Thread tid owns K values [ld_k, ld_k + 8) of tile row ld_row, so the Q8
    // scale and the token bounds check are evaluated once per thread.
    uint ld_row = tid / 4;          // 0..31
    uint ld_k   = (tid % 4) * 8;    // 0, 8, 16, 24
    uint ld_tok = tok_base + ld_row;
    bool ld_tok_ok = ld_tok < num_tokens;
    device const uchar* w_row = matrix + (row_base + ld_row) * row_bytes;
    threadgroup half* w_dst = w_tile + ld_row * 32 + ld_k;
    threadgroup half* t_dst = t_tile + ld_row * 32 + ld_k;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant to FP16.
        {
            device const uchar* rp = w_row + blk * 34;
            float sc = float(*(device const half*)rp);
            device const int8_t* q = (device const int8_t*)(rp + 2) + ld_k;
            for (uint ii = 0; ii < 8; ii++) {
                w_dst[ii] = half(float(int(q[ii])) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16, zeros past num_tokens).
        for (uint ii = 0; ii < 8; ii++) {
            t_dst[ii] = ld_tok_ok
                ? half(inputs[ld_tok * cols + blk * 32 + ld_k + ii])
                : half(0);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // All tile reads must finish before the next block overwrites the
        // tiles (or the store path below reuses tg_mem as float scratch).
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  Full-tile fast path assumes full 16×16 valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    if (out_tok_bot + 8 <= num_tokens) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else if (out_tok_top < num_tokens) {
        // Partial-token fallback: stage [C_00 | C_01 | C_10 | C_11] in this
        // simdgroup's 256-float slice of tg_mem, then copy in-bounds elements
        // using all 32 lanes.  Quadrants starting at or past num_tokens have
        // nothing to write and skip this path entirely.
        threadgroup float* scratch = tg_mem + simd_id * 256;
        simdgroup_store(C_00, scratch +   0, 8);
        simdgroup_store(C_01, scratch +  64, 8);
        simdgroup_store(C_10, scratch + 128, 8);
        simdgroup_store(C_11, scratch + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        for (uint e = lane; e < 256; e += 32) {
            uint acc = e / 64;  // 0: C_00, 1: C_01, 2: C_10, 3: C_11
            uint tok = ((acc < 2) ? out_tok_top : out_tok_bot) + (e % 64) / 8;
            if (tok < num_tokens) {
                uint row = ((acc % 2 == 0) ? out_row_lo : out_row_hi) + e % 8;
                outputs[tok * rows + row] = scratch[e];
            }
        }
    }
}
2017
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Computes outputs[tok, row] = sum_k inputs[tok, k] * W[row, k] where W is a
// Q8_0 matrix of `rows` rows, each made of cols / 32 blocks, inputs is f32
// [num_tokens, cols] and outputs is f32 [num_tokens, rows].  cols is treated
// as a multiple of 32 (blocks_per_row = cols / 32).  Each threadgroup
// produces one MMA32_TOK_TILE x MMA32_ROW_TILE tile of the output.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights are 8-bit signed integers times a per-block
// fp16 scale, and the token activations entering most projections are
// bounded (post-RMSNorm ≈ O(1)).  Summing 2048-wide K for 1B or 8192-wide
// for the FFN may exceed half's ~3.3-digit precision on extreme values, but
// the inputs are already quantized so the per-product error floor is higher
// than the half-precision rounding error.
// NOTE(review): f32 inputs are narrowed to half and the running sums stay in
// half, so any activation or partial sum beyond ±65504 becomes ±inf.  The
// FFN down-projection input is not an RMSNorm output — confirm the O(1)
// bound also holds there.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// NOTE(review): the tile loaders index with tid * 8 (8 elements per thread)
// and the quadrant mapping uses simd_id / 2 and simd_id % 2.  Together these
// appear to assume 128 threads (4 simdgroups) and 32x32 tiles — confirm
// against the dispatch code.  Weight rows are bounds-checked neither on load
// nor on store, so rows is presumably required to be a multiple of
// MMA32_ROW_TILE.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // tgid.x selects the weight-row tile and tgid.y the token tile.  The exit
    // condition is uniform across the threadgroup, so returning here cannot
    // leave other threads waiting at the barriers below.
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // One 32-wide K block of dequantized weights and of input activations,
    // both row-major with a stride of 32 halves.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Each simdgroup owns a 16-token x 16-row quadrant of the output tile,
    // held as four 8x8 accumulators named C_<token half><row half>.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    // Q8_0 block = 2-byte half scale + 32 int8 quants = 34 bytes.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Dequantize this K block of the weight tile.  Each thread fills 8
        // consecutive elements: w = int8 * scale, rounded to half.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        // Stage the matching input activations as half.  Tokens past
        // num_tokens are zero-filled so they contribute nothing.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four 8-wide K slices.  B is loaded transposed, so each product is
        // (tokens x k) * (k x rows) and C ends up in [token][row] order.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Prevent the next block's tile writes from racing with this
        // block's simdgroup_loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;

    // 4 simdgroups x 4 accumulators x 64 halves.  Each simdgroup writes only
    // its own 256-half slice, so a simdgroup-scope barrier is enough before
    // lane 0 reads it back.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);
    // Lane 0 copies the simdgroup's 16x16 quadrant to device memory, widening
    // to f32 and skipping tokens past num_tokens.
    // NOTE(review): these 256 stores are serialized on a single lane.
    uint lane = tid % 32;
    if (lane == 0) {
        for (uint m = 0; m < 8; m++) {
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base +   0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base +  64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst0[j] = float(src0[j]);
                    dst1[j] = float(src1[j]);
                }
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst2[j] = float(src2[j]);
                    dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2155
// ── add_bias_batch ─────────────────────────────────────────────────────
// Adds the same length-`rows` bias vector to each token's output row of a
// row-major [num_tokens, rows] buffer, in place.  Used for Qwen2 QKV bias
// after the fused qkv matmul:
//     out[token, i] += bias[i]    for i in 0..rows, token in 0..num_tokens
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows]
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // One thread per output element.  Threads beyond the end of the buffer
    // (from rounding up the grid) do nothing.
    if (id < num_tokens * rows) {
        out[id] += bias[id % rows];
    }
}
2172
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors:
//     outputs[token, row] = sum_k W[row, k] * inputs[token, k]
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups (1-D; consecutive
// threadgroups cover the row tiles of one token before moving to the next).
//
// Q4_0 block (32 weights, 18 bytes): a half scale, then 16 bytes of nibbles.
// Byte j holds weight j in its low nibble and weight j+16 in its high
// nibble, each stored with a +8 bias.
//
// Each simdgroup produces 4 consecutive rows starting at row_base.  Its 32
// lanes stride over the row's blocks, and the partial sums are combined
// with simd_sum.
//
// NOTE(review): the staging loop strides by 256, so a 256-thread
// threadgroup is assumed.  cols must fit in VEC_TILE_SIZE, and the 4-row
// unroll presumably matches Q4_ROWS_PER_SG — confirm against the dispatcher.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory; every simdgroup
    // reads all of it.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No threadgroup barriers follow, so a simdgroup with no rows left can
    // exit on its own.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // When rows is not a multiple of 4, the last simdgroup's trailing rows
    // do not exist.  Point them at the last real row so every load stays
    // inside the weight buffer; their sums are discarded by the guarded
    // stores at the end.  rows >= 1 here because row_base < rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;  // byte offset of this block within a row
        uint vb = blk * 32;  // matching element offset in the input vector

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 hold elements 0..15 (low nibbles), v4..v7 elements 16..31
        // (high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        // pN[i] packs bytes 4i..4i+3: low nibbles pair with v_i, high
        // nibbles with v_{i+4}.
        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        // The block scale is applied once per block, after the integer-weight dot.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Only rows that really exist are written; clamped tail rows are dropped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2273
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatters the K (or V) slice of every row of a strided batch QKV buffer
// into the KV cache, starting at cache position base_pos.
//   src: [M, src_stride] floats; the slice begins at src_offset in each row.
//   dst: contiguous [max_seq, kv_dim] cache; rows base_pos .. base_pos+M-1
//        are written.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer
    device float* dst        [[buffer(1)]],  // KV cache
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    // One thread per copied float; surplus threads from grid rounding exit.
    if (id >= M * kv_dim) return;

    uint row = id / kv_dim;          // token within the batch
    uint col = id - row * kv_dim;    // element within the KV vector

    dst[(base_pos + row) * kv_dim + col] = src[row * src_stride + src_offset + col];
}
2296
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// Grid: M * num_heads threadgroups (1-D, token-major).  The loops stride by
// 8 simdgroups / 256 threads and reduce through 8-entry shared arrays, so
// the threadgroup is assumed to be 256 threads (8 simdgroups of 32).
// K/V for every position up to base_pos + token_idx must already be in the
// cache before this kernel runs.
//
// NOTE(review): seq_len is never checked against ATTN_SCORES_SIZE, so longer
// contexts write past `scores`.  The float4 V loads/stores presumably also
// require head_dim to be a multiple of 4 so addresses stay 16-byte aligned —
// confirm.
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    // GQA: each group of num_heads / num_kv_heads query heads shares one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores — sized to the effective max_seq_len
    // (4096 for all supported models) so long-context attention doesn't overflow.
    // NOTE(review): nothing here enforces seq_len <= ATTN_SCORES_SIZE.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction
    // Each of the 8 simdgroups takes every 8th position.  Its 32 lanes split
    // head_dim and simd_sum combines the partial dot products.  Scores are
    // scaled by 1/sqrt(head_dim).
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a: max over all scores: per-thread max, then simd_max, then tid 0
    //     folds the 8 simdgroup results and broadcasts via shared_max[0].
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate in place and sum.  tid 0 stores 1/sum in
    //     shared_sum[0] (0 if the sum is not positive).
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // 2c: normalize scores into probabilities.
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // Each active thread produces 4 output dims, so only head_dim/4 threads
    // do work here (16 for head_dim=64, versus 64 in a one-dim-per-thread
    // loop).  The intent is wide float4 V loads rather than higher thread
    // occupancy.  Positions are unrolled by 4, with a scalar tail.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2422
// ── attention_flash_batch ─────────────────────────────────────────────
// Streaming attention with online softmax.  Same grid as attention_batch
// (M × num_heads threadgroups, one per (token, head) pair) but the scores
// matrix is never materialized.  K/V positions are processed in a tile of
// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
// the standard flash-attention recurrence:
//
//     m_new   = max(m_old, tile_max)
//     alpha   = exp(m_old - m_new)
//     l_new   = alpha * l_old + sum(exp(S - m_new))
//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
//     O_final = O / l
//
// This avoids attention_batch's fixed-size `scores[ATTN_SCORES_SIZE]`
// buffer (which overflows once seq_len exceeds ATTN_SCORES_SIZE).  It also
// keeps threadgroup memory use at O(head_dim + FLASH_K_TILE) instead of
// O(seq_len).
//
// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
// NOTE(review): head_dim is not checked against FLASH_MAX_HEAD_DIM at run
// time; a larger head_dim would overflow q_sh/o_sh.  The loops stride by
// 8 simdgroups / 256 threads, so the threadgroup is assumed to be 256
// threads.
constant constexpr uint FLASH_K_TILE = 32;        // K/V positions per tile
constant constexpr uint FLASH_MAX_HEAD_DIM = 256; // capacity of q_sh / o_sh

kernel void attention_flash_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],
    constant uint& base_pos          [[buffer(5)]],
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    // GQA: each group of num_heads / num_kv_heads query heads shares one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: attend to [0, base_pos + token_idx]
    uint q_off = token_idx * q_stride + head * head_dim;
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup state:
    //   q_sh:      Q vector for this (token, head), loaded once
    //   o_sh:      running output vector, updated each K tile
    //   scores_sh: scores for the current K tile only
    //   stats:     [running max, running sum]  (see flash-attention recurrence)
    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float scores_sh[FLASH_K_TILE];
    threadgroup float stats[2];
    threadgroup float sg_scratch[8];  // simdgroup-level reduction buffer

    // --- Load Q (one row) and zero the running O ---
    for (uint d = tid; d < head_dim; d += 256) {
        q_sh[d] = q_batch[q_off + d];
        o_sh[d] = 0.0f;
    }
    if (tid == 0) {
        stats[0] = -INFINITY;
        stats[1] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = fast::rsqrt(float(head_dim));
    uint v_stride = num_kv_heads * head_dim;
    uint v_base = kv_head * head_dim;

    // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);

        // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
        // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
        for (uint ti = simd_id; ti < tile_n; ti += 8) {
            uint k_off = (kv_base + ti) * v_stride + v_base;  // same layout as V stride
            float dot = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q_sh[d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores_sh[ti] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Tile max via cooperative reduction.  tile_n ≤ FLASH_K_TILE, so
        // only the first tile_n threads see a score; other simdgroups report
        // -INFINITY.
        float local_max = -INFINITY;
        for (uint s = tid; s < tile_n; s += 256) {
            local_max = max(local_max, scores_sh[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // [3] Merge with running max, compute alpha, rescale running l.
        // Only tid 0 computes these; the barrier below publishes them to the
        // rest of the threadgroup through sg_scratch[0..1].
        float m_new;
        float alpha;
        if (tid == 0) {
            float tile_max = sg_scratch[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
            float m_old = stats[0];
            m_new = max(m_old, tile_max);
            // First iteration: m_old = -inf → alpha = 0 (reset O).
            alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            stats[0] = m_new;
            stats[1] *= alpha;
            // Broadcast via sg_scratch.
            sg_scratch[0] = alpha;
            sg_scratch[1] = m_new;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        alpha = sg_scratch[0];
        m_new = sg_scratch[1];

        // [4] Rescale running output by alpha, then compute exp(scores - m_new).
        // The barrier that follows also orders these reads of sg_scratch
        // before step [5] overwrites it.
        for (uint d = tid; d < head_dim; d += 256) {
            o_sh[d] *= alpha;
        }
        for (uint s = tid; s < tile_n; s += 256) {
            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Tile sum → update running l.
        float local_sum = 0.0f;
        for (uint s = tid; s < tile_n; s += 256) {
            local_sum += scores_sh[s];
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
            stats[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = 0.0f;
            for (uint s = 0; s < tile_n; s++) {
                acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
            }
            o_sh[d] += acc;
        }
        // Also protects scores_sh / sg_scratch from the next tile's writes.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // --- Normalize and write output ---
    // stats[1] is the final softmax denominator l (0 guards against division by zero).
    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        output_batch[out_off + d] = o_sh[d] * inv_l;
    }
}
2587
// ── attention_mma_flash_batch ─────────────────────────────────────────
// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
// both Q·K^T and P·V.  Processes Q_BLOCK=8 tokens of one head per
// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
// K/V loads across 8 Q rows and using hardware matrix-multiply for the
// arithmetic.
//
// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
// to attention_batch / attention_flash_batch otherwise.
//
// Online softmax recurrence is identical to attention_flash_batch but
// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
//
// The kernel's threadgroup arrays are sized from the constants below.  At
// the current values they total 24256 bytes (~23.7 KiB); raising any of
// them grows that budget, which must stay within the device's threadgroup
// memory limit (32 KiB on current Apple GPUs — confirm for new targets).
constant constexpr uint FLASH_MMA_Q_BLOCK = 8;        // query tokens (Q tile rows) per threadgroup
constant constexpr uint FLASH_MMA_K_BLOCK = 32;       // K/V positions loaded per tile
constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128; // compile-time head_dim capacity of the tiles
2604
2605kernel void attention_mma_flash_batch(
2606    device const float* q_batch      [[buffer(0)]],
2607    device const float* k_cache      [[buffer(1)]],
2608    device const float* v_cache      [[buffer(2)]],
2609    device float* output_batch       [[buffer(3)]],
2610    constant uint& M                 [[buffer(4)]],
2611    constant uint& base_pos          [[buffer(5)]],
2612    constant uint& num_heads         [[buffer(6)]],
2613    constant uint& num_kv_heads      [[buffer(7)]],
2614    constant uint& head_dim          [[buffer(8)]],
2615    constant uint& q_stride          [[buffer(9)]],
2616    uint2 tgid [[threadgroup_position_in_grid]],
2617    uint tid [[thread_index_in_threadgroup]],
2618    uint simd_lane [[thread_index_in_simdgroup]],
2619    uint simd_id [[simdgroup_index_in_threadgroup]])
2620{
2621    uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2622    uint head = tgid.y;
2623    if (q_block_start >= M) return;
2624    uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2625
2626    uint kv_head = head / (num_heads / num_kv_heads);
2627    // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2628    // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2629    uint seq_len = base_pos + q_block_start + q_valid;
2630    float scale = fast::rsqrt(float(head_dim));
2631
2632    uint kv_stride = num_kv_heads * head_dim;
2633    uint kv_base_off = kv_head * head_dim;
2634
2635    // ── Threadgroup memory ──
2636    // q_sh:  [Q_BLOCK, head_dim] half  — Q tile, loaded once
2637    // k_sh:  [K_BLOCK, head_dim] half  — K tile, refreshed per kv_base iter
2638    // v_sh:  [K_BLOCK, head_dim] half  — V tile, refreshed per kv_base iter
2639    // s_sh:  [Q_BLOCK, K_BLOCK] float  — raw Q·K^T scores, then scaled+masked
2640    // p_sh:  [Q_BLOCK, K_BLOCK] half   — softmax probabilities (for P·V MMA)
2641    // o_sh:  [Q_BLOCK, head_dim] float — running output accumulator
2642    // m_sh:  [Q_BLOCK] float           — running max per Q row
2643    // l_sh:  [Q_BLOCK] float           — running softmax denominator per Q row
2644    // scratch: 4*Q_BLOCK floats        — per-row reduction scratch
2645    threadgroup half  q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2646    threadgroup half  k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2647    threadgroup half  v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2648    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2649    threadgroup half  p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2650    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2651    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2652    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2653    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2654
2655    // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2656    uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2657    for (uint i = tid; i < qblock_elems; i += 128) {
2658        uint q = i / head_dim;
2659        uint d = i % head_dim;
2660        if (q < q_valid) {
2661            uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2662            q_sh[q * head_dim + d] = half(q_batch[q_off]);
2663        } else {
2664            q_sh[q * head_dim + d] = half(0);
2665        }
2666        o_sh[q * head_dim + d] = 0.0f;
2667    }
2668    if (tid < FLASH_MMA_Q_BLOCK) {
2669        m_sh[tid] = -INFINITY;
2670        l_sh[tid] = 0.0f;
2671    }
2672    threadgroup_barrier(mem_flags::mem_threadgroup);
2673
2674    // ── Stream K/V in K_BLOCK chunks ──
2675    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2676        uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2677
2678        // Load K and V tile into TG memory (as half).
2679        uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2680        for (uint i = tid; i < kv_tile_elems; i += 128) {
2681            uint k_pos = i / head_dim;
2682            uint d = i % head_dim;
2683            if (k_pos < tile_n) {
2684                uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2685                k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2686                v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2687            } else {
2688                k_sh[k_pos * head_dim + d] = half(0);
2689                v_sh[k_pos * head_dim + d] = half(0);
2690            }
2691        }
2692        threadgroup_barrier(mem_flags::mem_threadgroup);
2693
2694        // ── Phase 1: S = Q @ K^T via MMA ──
2695        // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2696        // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2697        // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2698        {
2699            simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2700            uint dim_chunks = head_dim / 8;
2701            for (uint dc = 0; dc < dim_chunks; dc++) {
2702                simdgroup_matrix<half, 8, 8> A, B;
2703                // A = Q[0:8, dc*8 : dc*8+8]  (rows of Q, no transpose)
2704                simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2705                // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2706                // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2707                // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2708                // transpose=true, which places it in the register as K^T of that sub-block.
2709                simdgroup_load(B,
2710                    k_sh + (simd_id * 8) * head_dim + dc * 8,
2711                    head_dim, ulong2(0, 0), true);
2712                simdgroup_multiply_accumulate(C, A, B, C);
2713            }
2714            // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2715            simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2716        }
2717        threadgroup_barrier(mem_flags::mem_threadgroup);
2718
2719        // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2720        // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2721        uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2722        for (uint i = tid; i < s_elems; i += 128) {
2723            uint q = i / FLASH_MMA_K_BLOCK;
2724            uint k = i % FLASH_MMA_K_BLOCK;
2725            uint global_q = q_block_start + q;
2726            uint global_kv = kv_base + k;
2727            bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2728            s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2729        }
2730        threadgroup_barrier(mem_flags::mem_threadgroup);
2731
2732        // ── Phase 2b: per-row max via simdgroup reduction ──
2733        // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2734        // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2735        {
2736            uint row_base = simd_id * 2;
2737            for (uint qr = 0; qr < 2; qr++) {
2738                uint q = row_base + qr;
2739                float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2740                float row_max = simd_max(my);
2741                if (simd_lane == 0) {
2742                    scratch[q] = row_max;  // tile_max[q]
2743                }
2744            }
2745        }
2746        threadgroup_barrier(mem_flags::mem_threadgroup);
2747
2748        // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2749        if (tid < FLASH_MMA_Q_BLOCK) {
2750            uint q = tid;
2751            float m_old = m_sh[q];
2752            float tile_max = scratch[q];
2753            float m_new = max(m_old, tile_max);
2754            float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2755            m_sh[q] = m_new;
2756            l_sh[q] = l_sh[q] * alpha;
2757            // scratch[q]               = m_new   (for phase 2d)
2758            // scratch[Q_BLOCK + q]     = alpha   (for phase 3)
2759            scratch[q] = m_new;
2760            scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2761        }
2762        threadgroup_barrier(mem_flags::mem_threadgroup);
2763
2764        // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2765        {
2766            uint row_base = simd_id * 2;
2767            for (uint qr = 0; qr < 2; qr++) {
2768                uint q = row_base + qr;
2769                float m_new = scratch[q];
2770                float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2771                p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2772                float row_sum = simd_sum(p);
2773                if (simd_lane == 0) {
2774                    scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2775                }
2776            }
2777        }
2778        threadgroup_barrier(mem_flags::mem_threadgroup);
2779
2780        // ── Phase 2e: l_sh += tile_sum ──
2781        if (tid < FLASH_MMA_Q_BLOCK) {
2782            uint q = tid;
2783            l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2784        }
2785
2786        // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2787        threadgroup_barrier(mem_flags::mem_threadgroup);
2788        for (uint i = tid; i < qblock_elems; i += 128) {
2789            uint q = i / head_dim;
2790            float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2791            o_sh[i] *= alpha;
2792        }
2793        threadgroup_barrier(mem_flags::mem_threadgroup);
2794
2795        // ── Phase 4: O += P @ V via MMA ──
2796        // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2797        // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2798        // For head_dim=64: 16 dims/sg = 2 tiles.  head_dim=128: 32 dims/sg = 4 tiles.
2799        {
2800            uint dims_per_sg = head_dim / 4;       // 16 or 32
2801            uint tiles_per_sg = dims_per_sg / 8;   // 2 or 4
2802            uint sg_d_base = simd_id * dims_per_sg;
2803            for (uint t = 0; t < tiles_per_sg; t++) {
2804                uint d_base = sg_d_base + t * 8;
2805                simdgroup_matrix<float, 8, 8> O_acc;
2806                simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2807                uint k_chunks = FLASH_MMA_K_BLOCK / 8;  // 4
2808                for (uint kc = 0; kc < k_chunks; kc++) {
2809                    simdgroup_matrix<half, 8, 8> A, B;
2810                    simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2811                                   ulong2(0, 0), false);
2812                    simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2813                                   ulong2(0, 0), false);
2814                    simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2815                }
2816                simdgroup_store(O_acc, o_sh + d_base, head_dim);
2817            }
2818        }
2819        threadgroup_barrier(mem_flags::mem_threadgroup);
2820    }
2821
2822    // ── Finalize: O /= l, write to output ──
2823    for (uint i = tid; i < qblock_elems; i += 128) {
2824        uint q = i / head_dim;
2825        uint d = i % head_dim;
2826        if (q < q_valid) {
2827            float l = l_sh[q];
2828            float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2829            uint token_idx = q_block_start + q;
2830            uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2831            output_batch[out_off] = o_sh[i] * inv_l;
2832        }
2833    }
2834}
2835
2836// ── rope_qk_batch ─────────────────────────────────────────────────────
2837// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2838// launch + memory barrier per layer. Both Q and K live in the same
2839// qkv_data buffer at different offsets within each token's row.
2840// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2841kernel void rope_qk_batch(
2842    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
2843    constant uint& M                 [[buffer(1)]],   // num tokens
2844    constant uint& base_pos          [[buffer(2)]],   // starting position
2845    constant uint& num_q_heads       [[buffer(3)]],
2846    constant uint& num_kv_heads      [[buffer(4)]],
2847    constant uint& head_dim          [[buffer(5)]],
2848    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
2849    constant float& theta            [[buffer(7)]],
2850    uint id [[thread_position_in_grid]])
2851{
2852    uint half_dim = head_dim / 2;
2853    uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2854    uint token = id / total_pairs;
2855    uint pair = id % total_pairs;
2856    if (token >= M) return;
2857
2858    uint pos = base_pos + token;
2859    uint q_pairs = num_q_heads * half_dim;
2860
2861    uint h, i, offset;
2862    if (pair < q_pairs) {
2863        // Q head
2864        h = pair / half_dim;
2865        i = pair % half_dim;
2866        offset = token * qkv_stride + h * head_dim + i * 2;
2867    } else {
2868        // K head
2869        uint kp = pair - q_pairs;
2870        h = kp / half_dim;
2871        i = kp % half_dim;
2872        uint k_start = num_q_heads * head_dim;
2873        offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2874    }
2875
2876    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2877    float angle = float(pos) * freq;
2878    float cos_val = cos(angle);
2879    float sin_val = sin(angle);
2880
2881    float x0 = qkv_data[offset];
2882    float x1 = qkv_data[offset + 1];
2883    qkv_data[offset]     = x0 * cos_val - x1 * sin_val;
2884    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2885}
2886
2887// ── copy_kv_both_batch ────────────────────────────────────────────────
2888// Fused K+V cache copy in a single dispatch: copies both K and V from
2889// the strided batch QKV buffer to their respective KV cache buffers.
2890// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2891kernel void copy_kv_both_batch(
2892    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
2893    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
2894    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
2895    constant uint& M           [[buffer(3)]],  // num tokens in batch
2896    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
2897    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
2898    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
2899    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
2900    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
2901    uint id [[thread_position_in_grid]])
2902{
2903    // Total elements = M * kv_dim * 2 (K + V)
2904    uint total_kv = M * kv_dim;
2905    if (id >= total_kv * 2) return;
2906
2907    uint is_v = id / total_kv;        // 0 = K, 1 = V
2908    uint local_id = id % total_kv;
2909    uint token = local_id / kv_dim;
2910    uint d = local_id % kv_dim;
2911
2912    uint dst_off = (base_pos + token) * kv_dim + d;
2913    uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2914
2915    if (is_v) {
2916        v_dst[dst_off] = src[src_off];
2917    } else {
2918        k_dst[dst_off] = src[src_off];
2919    }
2920}
2921"#
2922    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2923    .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2924}
2925
2926// ---------------------------------------------------------------------------
2927// model.rs generation
2928// ---------------------------------------------------------------------------
2929
2930fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2931    let mut code = String::with_capacity(48 * 1024);
2932    emit_model_header(&mut code, config)?;
2933    emit_metal_model_struct(&mut code, config)?;
2934    emit_layer_buffers_struct(&mut code, config)?;
2935    emit_metal_model_impl(&mut code, config)?;
2936    emit_helper_functions(&mut code)?;
2937    Ok(code)
2938}
2939
2940fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2941    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2942    writeln!(
2943        code,
2944        "//! Model: {} ({} layers, hidden={})",
2945        config.architecture, config.num_layers, config.hidden_size
2946    )?;
2947    writeln!(code, "//!")?;
2948    writeln!(
2949        code,
2950        "//! Uses native Metal compute pipelines via the metal crate."
2951    )?;
2952    writeln!(code)?;
2953    writeln!(code, "#![allow(dead_code)]")?;
2954    writeln!(code)?;
2955    writeln!(code, "use metal::*;")?;
2956    writeln!(code, "#[allow(unused_imports)]")?;
2957    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2958    writeln!(code, "use std::mem;")?;
2959    writeln!(code)?;
2960
2961    // Model constants
2962    writeln!(
2963        code,
2964        "// ── Model constants ──────────────────────────────────"
2965    )?;
2966    writeln!(
2967        code,
2968        "pub const HIDDEN_SIZE: usize = {};",
2969        config.hidden_size
2970    )?;
2971    writeln!(
2972        code,
2973        "pub const INTERMEDIATE_SIZE: usize = {};",
2974        config.intermediate_size
2975    )?;
2976    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2977    writeln!(
2978        code,
2979        "pub const NUM_HEADS: usize = {};",
2980        config.num_attention_heads
2981    )?;
2982    writeln!(
2983        code,
2984        "pub const NUM_KV_HEADS: usize = {};",
2985        config.num_kv_heads
2986    )?;
2987    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2988    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2989    let effective_seq_len = config.max_seq_len.min(4096);
2990    writeln!(
2991        code,
2992        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
2993        effective_seq_len, config.max_seq_len
2994    )?;
2995    writeln!(
2996        code,
2997        "pub const RMS_NORM_EPS: f32 = {:e};",
2998        config.rms_norm_eps
2999    )?;
3000    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3001    writeln!(
3002        code,
3003        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3004    )?;
3005    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3006    writeln!(code)?;
3007
3008    Ok(())
3009}
3010
3011fn emit_metal_model_struct(
3012    code: &mut String,
3013    config: &ModelConfig,
3014) -> Result<(), MetalCodegenError> {
3015    writeln!(
3016        code,
3017        "// ── MetalModel ──────────────────────────────────────────"
3018    )?;
3019    writeln!(code)?;
3020    writeln!(
3021        code,
3022        "/// Metal-accelerated transformer model for Apple Silicon."
3023    )?;
3024    writeln!(code, "///")?;
3025    writeln!(
3026        code,
3027        "/// Uses unified memory for zero-copy weight access and native Metal"
3028    )?;
3029    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3030    writeln!(code, "pub struct MetalModel {{")?;
3031    writeln!(code, "    device: Device,")?;
3032    writeln!(code, "    queue: CommandQueue,")?;
3033    writeln!(code)?;
3034    writeln!(code, "    // ── Compute pipelines ──")?;
3035    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
3036    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
3037    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
3038    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
3039    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
3040    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
3041    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
3042    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
3043    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
3044    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
3045    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
3046    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
3047    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
3048    writeln!(code)?;
3049    writeln!(code, "    // ── Batched prefill pipelines ──")?;
3050    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
3051    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
3052    writeln!(
3053        code,
3054        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3055    )?;
3056    writeln!(code, "    matmul_q8_mma_pipeline: ComputePipelineState,")?;
3057    writeln!(code, "    matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3058    writeln!(
3059        code,
3060        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3061    )?;
3062    writeln!(
3063        code,
3064        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3065    )?;
3066    writeln!(
3067        code,
3068        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3069    )?;
3070    if config.qkv_bias {
3071        writeln!(code, "    add_bias_batch_pipeline: ComputePipelineState,")?;
3072    }
3073    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
3074    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
3075    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
3076    writeln!(
3077        code,
3078        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
3079    )?;
3080    writeln!(
3081        code,
3082        "    add_inplace_batch_pipeline: ComputePipelineState,"
3083    )?;
3084    writeln!(
3085        code,
3086        "    copy_embedding_batch_pipeline: ComputePipelineState,"
3087    )?;
3088    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
3089    writeln!(
3090        code,
3091        "    attention_flash_batch_pipeline: ComputePipelineState,"
3092    )?;
3093    writeln!(
3094        code,
3095        "    attention_mma_flash_batch_pipeline: ComputePipelineState,"
3096    )?;
3097    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
3098    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
3099    writeln!(
3100        code,
3101        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
3102    )?;
3103    writeln!(code)?;
3104    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
3105    writeln!(
3106        code,
3107        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3108    )?;
3109    writeln!(code, "    embed_buf: Buffer,")?;
3110    writeln!(code)?;
3111    writeln!(code, "    /// Per-layer weight buffers")?;
3112    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
3113    writeln!(code)?;
3114    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
3115    writeln!(code, "    norm_buf: Buffer,")?;
3116    writeln!(code)?;
3117    writeln!(
3118        code,
3119        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3120    )?;
3121    writeln!(code, "    lm_head_buf: Buffer,")?;
3122    writeln!(code)?;
3123    writeln!(
3124        code,
3125        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
3126    )?;
3127    writeln!(code, "    hidden_buf: Buffer,")?;
3128    writeln!(code, "    residual_buf: Buffer,")?;
3129    writeln!(code, "    normed_buf: Buffer,")?;
3130    writeln!(
3131        code,
3132        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3133    )?;
3134    writeln!(code, "    qkv_buf: Buffer,")?;
3135    writeln!(code, "    attn_out_buf: Buffer,")?;
3136    writeln!(code, "    attn_proj_buf: Buffer,")?;
3137    writeln!(
3138        code,
3139        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3140    )?;
3141    writeln!(code, "    gate_up_buf: Buffer,")?;
3142    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
3143    writeln!(code, "    ffn_out_buf: Buffer,")?;
3144    writeln!(code, "    add_tmp_buf: Buffer,")?;
3145    writeln!(code, "    logits_buf: Buffer,")?;
3146    writeln!(code)?;
3147    writeln!(code, "    // ── Batched prefill working buffers ──")?;
3148    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3149    writeln!(code, "    batch_hidden_buf: Buffer,")?;
3150    writeln!(
3151        code,
3152        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3153    )?;
3154    writeln!(code, "    batch_residual_buf: Buffer,")?;
3155    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3156    writeln!(code, "    batch_qkv_buf: Buffer,")?;
3157    writeln!(
3158        code,
3159        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3160    )?;
3161    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
3162    writeln!(
3163        code,
3164        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3165    )?;
3166    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
3167    writeln!(
3168        code,
3169        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3170    )?;
3171    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
3172    writeln!(
3173        code,
3174        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3175    )?;
3176    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
3177    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3178    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
3179    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
3180    writeln!(code, "    batch_tokens_buf: Buffer,")?;
3181    writeln!(code, "    /// Positions buffer for batch RoPE")?;
3182    writeln!(code, "    batch_positions_buf: Buffer,")?;
3183    writeln!(code)?;
3184    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
3185    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
3186    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
3187    writeln!(code)?;
3188    writeln!(code, "    // ── Inference state ──")?;
3189    writeln!(code, "    pos: usize,")?;
3190    writeln!(code)?;
3191    writeln!(
3192        code,
3193        "    /// Previous command buffer for double-buffered prefill."
3194    )?;
3195    writeln!(
3196        code,
3197        "    /// While the GPU executes token N, the CPU can encode token N+1."
3198    )?;
3199    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
3200    writeln!(code, "}}")?;
3201    writeln!(code)?;
3202
3203    Ok(())
3204}
3205
3206fn emit_layer_buffers_struct(
3207    code: &mut String,
3208    config: &ModelConfig,
3209) -> Result<(), MetalCodegenError> {
3210    writeln!(
3211        code,
3212        "/// Per-layer weight buffers for attention and FFN projections."
3213    )?;
3214    writeln!(code, "struct LayerBuffers {{")?;
3215    writeln!(code, "    attn_norm: Buffer,")?;
3216    writeln!(
3217        code,
3218        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3219    )?;
3220    writeln!(code, "    qkv_weight: Buffer,")?;
3221    if config.qkv_bias {
3222        writeln!(
3223            code,
3224            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3225        )?;
3226        writeln!(code, "    qkv_bias: Buffer,")?;
3227    }
3228    writeln!(code, "    o_weight: Buffer,")?;
3229    writeln!(code, "    ffn_norm: Buffer,")?;
3230    writeln!(
3231        code,
3232        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3233    )?;
3234    writeln!(code, "    gate_up_weight: Buffer,")?;
3235    writeln!(code, "    down_weight: Buffer,")?;
3236    writeln!(code, "}}")?;
3237    writeln!(code)?;
3238
3239    Ok(())
3240}
3241
3242fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3243    let hidden = config.hidden_size;
3244    let intermediate = config.intermediate_size;
3245    let _num_layers = config.num_layers;
3246    let num_heads = config.num_attention_heads;
3247    let num_kv_heads = config.num_kv_heads;
3248    let head_dim = config.head_dim;
3249    let vocab = config.vocab_size;
3250    let effective_seq_len = config.max_seq_len.min(4096);
3251    let is_q8 = config.dtype == DType::Q8_0;
3252    let is_q4 = config.dtype == DType::Q4_0;
3253    let kv_dim = num_kv_heads * head_dim;
3254
3255    writeln!(code, "impl MetalModel {{")?;
3256
3257    // ── new() ──
3258    writeln!(
3259        code,
3260        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3261    )?;
3262    writeln!(code, "    ///")?;
3263    writeln!(
3264        code,
3265        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3266    )?;
3267    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3268    writeln!(
3269        code,
3270        "        let device = Device::system_default().expect(\"no Metal device found\");"
3271    )?;
3272    writeln!(code, "        let queue = device.new_command_queue();")?;
3273    writeln!(code)?;
3274
3275    // Compile shaders
3276    writeln!(
3277        code,
3278        "        // Compile Metal shaders from embedded source"
3279    )?;
3280    writeln!(
3281        code,
3282        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3283    )?;
3284    writeln!(
3285        code,
3286        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3287    )?;
3288    writeln!(
3289        code,
3290        "            .expect(\"failed to compile Metal shaders\");"
3291    )?;
3292    writeln!(code)?;
3293
3294    // Create compute pipelines
3295    writeln!(code, "        // Create compute pipelines")?;
3296    for (var, fn_name) in [
3297        ("matmul_pipeline", "matmul_vec"),
3298        ("matmul_q8_pipeline", "matmul_vec_q8"),
3299        ("matmul_q4_pipeline", "matmul_vec_q4"),
3300        ("rms_norm_pipeline", "rms_norm"),
3301        ("rope_pipeline", "rope"),
3302        ("softmax_pipeline", "softmax"),
3303        ("silu_mul_pipeline", "silu_mul"),
3304        ("silu_mul_fused_pipeline", "silu_mul_fused"),
3305        ("add_pipeline", "elementwise_add"),
3306        ("attention_pipeline", "attention"),
3307        ("add_inplace_pipeline", "add_inplace"),
3308        ("copy_pipeline", "copy_buffer"),
3309        ("copy_offset_pipeline", "copy_offset"),
3310        ("matmul_batch_pipeline", "matmul_vec_batch"),
3311        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3312        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3313        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3314        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3315        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3316        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3317        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3318        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3319        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3320        ("rope_batch_pipeline", "rope_batch"),
3321        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3322        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3323        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3324        ("attention_batch_pipeline", "attention_batch"),
3325        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3326        (
3327            "attention_mma_flash_batch_pipeline",
3328            "attention_mma_flash_batch",
3329        ),
3330        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3331        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3332        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3333    ] {
3334        writeln!(
3335            code,
3336            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3337        )?;
3338    }
3339    if config.qkv_bias {
3340        writeln!(
3341            code,
3342            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3343        )?;
3344    }
3345    writeln!(code)?;
3346
3347    // Weight loading
3348    writeln!(
3349        code,
3350        "        // Load weights into Metal shared-memory buffers"
3351    )?;
3352    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3353    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3354    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3355    writeln!(code)?;
3356    writeln!(
3357        code,
3358        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3359    )?;
3360    writeln!(code)?;
3361    writeln!(
3362        code,
3363        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3364    )?;
3365    writeln!(
3366        code,
3367        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3368    )?;
3369    writeln!(code, "            let byte_len = n * f32_size;")?;
3370    writeln!(code, "            let cur = cursor.get();")?;
3371    writeln!(
3372        code,
3373        "            let data = &weights[cur..cur + byte_len];"
3374    )?;
3375    writeln!(code, "            cursor.set(cur + byte_len);")?;
3376    writeln!(code, "            device.new_buffer_with_data(")?;
3377    writeln!(code, "                data.as_ptr() as *const _,")?;
3378    writeln!(code, "                byte_len as u64,")?;
3379    writeln!(
3380        code,
3381        "                MTLResourceOptions::StorageModeShared,"
3382    )?;
3383    writeln!(code, "            )")?;
3384    writeln!(code, "        }};")?;
3385    writeln!(code)?;
3386
3387    if is_q8 {
3388        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3389        // We load them directly into Metal buffers without dequantizing,
3390        // and use the matmul_vec_q8 shader that operates on quantized data.
3391        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
3392        writeln!(
3393            code,
3394            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3395        )?;
3396        writeln!(
3397            code,
3398            "        // as raw bytes into a Metal buffer (no dequantization)."
3399        )?;
3400        writeln!(
3401            code,
3402            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3403        )?;
3404        writeln!(
3405            code,
3406            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3407        )?;
3408        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3409        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3410        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3411        writeln!(code, "            let cur = cursor.get();")?;
3412        writeln!(
3413            code,
3414            "            let data = &weights[cur..cur + total_raw];"
3415        )?;
3416        writeln!(code, "            cursor.set(cur + total_raw);")?;
3417        writeln!(code, "            device.new_buffer_with_data(")?;
3418        writeln!(code, "                data.as_ptr() as *const _,")?;
3419        writeln!(code, "                total_raw as u64,")?;
3420        writeln!(
3421            code,
3422            "                MTLResourceOptions::StorageModeShared,"
3423        )?;
3424        writeln!(code, "            )")?;
3425        writeln!(code, "        }};")?;
3426        writeln!(code)?;
3427        writeln!(
3428            code,
3429            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3430        )?;
3431        writeln!(
3432            code,
3433            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3434        )?;
3435        writeln!(
3436            code,
3437            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3438        )?;
3439        writeln!(
3440            code,
3441            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3442        )?;
3443        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3444        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3445        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3446        writeln!(code, "            let cur = cursor.get();")?;
3447        writeln!(
3448            code,
3449            "            let data = &weights[cur..cur + total_raw];"
3450        )?;
3451        writeln!(code, "            cursor.set(cur + total_raw);")?;
3452        writeln!(code, "            device.new_buffer_with_data(")?;
3453        writeln!(code, "                data.as_ptr() as *const _,")?;
3454        writeln!(code, "                total_raw as u64,")?;
3455        writeln!(
3456            code,
3457            "                MTLResourceOptions::StorageModeShared,"
3458        )?;
3459        writeln!(code, "            )")?;
3460        writeln!(code, "        }};")?;
3461        writeln!(code)?;
3462    }
3463
3464    if is_q4 {
3465        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3466        // We load them directly into Metal buffers without dequantizing,
3467        // and use the matmul_vec_q4 shader that operates on quantized data.
3468        // This quarters GPU memory usage vs f32 dequantization.
3469        writeln!(
3470            code,
3471            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3472        )?;
3473        writeln!(
3474            code,
3475            "        // as raw bytes into a Metal buffer (no dequantization)."
3476        )?;
3477        writeln!(
3478            code,
3479            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3480        )?;
3481        writeln!(
3482            code,
3483            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3484        )?;
3485        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3486        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3487        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3488        writeln!(code, "            let cur = cursor.get();")?;
3489        writeln!(
3490            code,
3491            "            let data = &weights[cur..cur + total_raw];"
3492        )?;
3493        writeln!(code, "            cursor.set(cur + total_raw);")?;
3494        writeln!(code, "            device.new_buffer_with_data(")?;
3495        writeln!(code, "                data.as_ptr() as *const _,")?;
3496        writeln!(code, "                total_raw as u64,")?;
3497        writeln!(
3498            code,
3499            "                MTLResourceOptions::StorageModeShared,"
3500        )?;
3501        writeln!(code, "            )")?;
3502        writeln!(code, "        }};")?;
3503        writeln!(code)?;
3504        writeln!(
3505            code,
3506            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3507        )?;
3508        writeln!(
3509            code,
3510            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3511        )?;
3512        writeln!(
3513            code,
3514            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3515        )?;
3516        writeln!(
3517            code,
3518            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3519        )?;
3520        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3521        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3522        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3523        writeln!(code, "            let cur = cursor.get();")?;
3524        writeln!(
3525            code,
3526            "            let data = &weights[cur..cur + total_raw];"
3527        )?;
3528        writeln!(code, "            cursor.set(cur + total_raw);")?;
3529        writeln!(code, "            device.new_buffer_with_data(")?;
3530        writeln!(code, "                data.as_ptr() as *const _,")?;
3531        writeln!(code, "                total_raw as u64,")?;
3532        writeln!(
3533            code,
3534            "                MTLResourceOptions::StorageModeShared,"
3535        )?;
3536        writeln!(code, "            )")?;
3537        writeln!(code, "        }};")?;
3538        writeln!(code)?;
3539    }
3540
3541    writeln!(
3542        code,
3543        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3544    )?;
3545    writeln!(code)?;
3546
3547    // Per-layer weights
3548    writeln!(
3549        code,
3550        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3551    )?;
3552    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3553
3554    // attn_norm is always f32
3555    writeln!(
3556        code,
3557        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3558    )?;
3559
3560    let qkv_rows = hidden + 2 * kv_dim;
3561    if is_q8 {
3562        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3563        writeln!(
3564            code,
3565            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3566        )?;
3567        if config.qkv_bias {
3568            writeln!(
3569                code,
3570                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3571            )?;
3572            writeln!(
3573                code,
3574                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3575            )?;
3576        }
3577        writeln!(
3578            code,
3579            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3580        )?;
3581    } else if is_q4 {
3582        writeln!(
3583            code,
3584            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3585        )?;
3586        if config.qkv_bias {
3587            writeln!(
3588                code,
3589                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3590            )?;
3591        }
3592        writeln!(
3593            code,
3594            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3595        )?;
3596    } else {
3597        writeln!(
3598            code,
3599            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3600        )?;
3601        if config.qkv_bias {
3602            writeln!(
3603                code,
3604                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3605            )?;
3606        }
3607        writeln!(
3608            code,
3609            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3610        )?;
3611    }
3612
3613    // ffn_norm is always f32
3614    writeln!(
3615        code,
3616        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3617    )?;
3618
3619    let gate_up_rows = 2 * intermediate;
3620    if is_q8 {
3621        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3622        writeln!(
3623            code,
3624            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3625        )?;
3626        writeln!(
3627            code,
3628            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3629        )?;
3630    } else if is_q4 {
3631        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3632        writeln!(
3633            code,
3634            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3635        )?;
3636        writeln!(
3637            code,
3638            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3639        )?;
3640    } else {
3641        // Fused gate+up weight: read both as a single contiguous f32 buffer
3642        writeln!(
3643            code,
3644            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3645        )?;
3646        writeln!(
3647            code,
3648            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3649        )?;
3650    }
3651
3652    writeln!(code, "            layers.push(LayerBuffers {{")?;
3653    writeln!(code, "                attn_norm,")?;
3654    writeln!(code, "                qkv_weight,")?;
3655    if config.qkv_bias {
3656        writeln!(code, "                qkv_bias,")?;
3657    }
3658    writeln!(code, "                o_weight,")?;
3659    writeln!(code, "                ffn_norm,")?;
3660    writeln!(code, "                gate_up_weight,")?;
3661    writeln!(code, "                down_weight,")?;
3662    writeln!(code, "            }});")?;
3663    writeln!(code, "        }}")?;
3664    writeln!(code)?;
3665
3666    // final_norm is always f32
3667    writeln!(
3668        code,
3669        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3670    )?;
3671    writeln!(code)?;
3672
3673    // lm_head
3674    if is_q8 {
3675        writeln!(
3676            code,
3677            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3678        )?;
3679    } else if is_q4 {
3680        writeln!(
3681            code,
3682            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3683        )?;
3684    } else {
3685        writeln!(
3686            code,
3687            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3688        )?;
3689    }
3690    writeln!(code)?;
3691
3692    // Working buffers
3693    let hidden_bytes = hidden * 4;
3694    let _kv_dim_bytes = kv_dim * 4;
3695    let intermediate_bytes = intermediate * 4;
3696    let vocab_bytes = vocab * 4;
3697    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3698
3699    writeln!(
3700        code,
3701        "        // Allocate working buffers (shared memory for zero-copy)"
3702    )?;
3703    writeln!(
3704        code,
3705        "        let opts = MTLResourceOptions::StorageModeShared;"
3706    )?;
3707    writeln!(
3708        code,
3709        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3710    )?;
3711    writeln!(
3712        code,
3713        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3714    )?;
3715    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3716    writeln!(
3717        code,
3718        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3719    )?;
3720    writeln!(
3721        code,
3722        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3723    )?;
3724    writeln!(
3725        code,
3726        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3727    )?;
3728    writeln!(
3729        code,
3730        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3731    )?;
3732    writeln!(
3733        code,
3734        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3735    )?;
3736    let gate_up_buf_bytes = 2 * intermediate * 4;
3737    writeln!(
3738        code,
3739        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3740    )?;
3741    writeln!(
3742        code,
3743        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3744    )?;
3745    writeln!(
3746        code,
3747        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3748    )?;
3749    writeln!(
3750        code,
3751        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3752    )?;
3753    writeln!(
3754        code,
3755        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3756    )?;
3757    writeln!(
3758        code,
3759        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3760    )?;
3761    writeln!(code)?;
3762
3763    // Batch prefill working buffers
3764    let batch_hidden_bytes = hidden * 4; // per-token
3765    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3766    let batch_gate_up_bytes = 2 * intermediate * 4;
3767    let batch_intermediate_bytes = intermediate * 4;
3768    writeln!(
3769        code,
3770        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3771    )?;
3772    writeln!(
3773        code,
3774        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3775    )?;
3776    writeln!(
3777        code,
3778        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3779    )?;
3780    writeln!(
3781        code,
3782        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3783    )?;
3784    writeln!(
3785        code,
3786        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3787    )?;
3788    writeln!(
3789        code,
3790        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3791    )?;
3792    writeln!(
3793        code,
3794        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3795    )?;
3796    writeln!(
3797        code,
3798        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3799    )?;
3800    writeln!(
3801        code,
3802        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3803    )?;
3804    writeln!(
3805        code,
3806        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3807    )?;
3808    writeln!(
3809        code,
3810        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3811    )?;
3812    writeln!(code)?;
3813
3814    // KV cache buffers
3815    writeln!(code, "        // KV cache buffers (per-layer)")?;
3816    writeln!(
3817        code,
3818        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3819    )?;
3820    writeln!(
3821        code,
3822        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3823    )?;
3824    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3825    writeln!(
3826        code,
3827        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3828    )?;
3829    writeln!(
3830        code,
3831        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3832    )?;
3833    writeln!(code, "        }}")?;
3834    writeln!(code)?;
3835
3836    writeln!(code, "        Self {{")?;
3837    writeln!(code, "            device,")?;
3838    writeln!(code, "            queue,")?;
3839    writeln!(code, "            matmul_pipeline,")?;
3840    writeln!(code, "            matmul_q8_pipeline,")?;
3841    writeln!(code, "            matmul_q4_pipeline,")?;
3842    writeln!(code, "            rms_norm_pipeline,")?;
3843    writeln!(code, "            rope_pipeline,")?;
3844    writeln!(code, "            softmax_pipeline,")?;
3845    writeln!(code, "            silu_mul_pipeline,")?;
3846    writeln!(code, "            silu_mul_fused_pipeline,")?;
3847    writeln!(code, "            add_pipeline,")?;
3848    writeln!(code, "            attention_pipeline,")?;
3849    writeln!(code, "            add_inplace_pipeline,")?;
3850    writeln!(code, "            copy_pipeline,")?;
3851    writeln!(code, "            copy_offset_pipeline,")?;
3852    writeln!(code, "            matmul_batch_pipeline,")?;
3853    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3854    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3855    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3856    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3857    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3858    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3859    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3860    if config.qkv_bias {
3861        writeln!(code, "            add_bias_batch_pipeline,")?;
3862    }
3863    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3864    writeln!(code, "            rms_norm_batch_pipeline,")?;
3865    writeln!(code, "            rope_batch_pipeline,")?;
3866    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3867    writeln!(code, "            add_inplace_batch_pipeline,")?;
3868    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3869    writeln!(code, "            attention_batch_pipeline,")?;
3870    writeln!(code, "            attention_flash_batch_pipeline,")?;
3871    writeln!(code, "            attention_mma_flash_batch_pipeline,")?;
3872    writeln!(code, "            copy_kv_batch_pipeline,")?;
3873    writeln!(code, "            rope_qk_batch_pipeline,")?;
3874    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3875    writeln!(code, "            embed_buf,")?;
3876    writeln!(code, "            layers,")?;
3877    writeln!(code, "            norm_buf,")?;
3878    writeln!(code, "            lm_head_buf,")?;
3879    writeln!(code, "            hidden_buf,")?;
3880    writeln!(code, "            residual_buf,")?;
3881    writeln!(code, "            normed_buf,")?;
3882    writeln!(code, "            qkv_buf,")?;
3883    writeln!(code, "            attn_out_buf,")?;
3884    writeln!(code, "            attn_proj_buf,")?;
3885    writeln!(code, "            gate_up_buf,")?;
3886    writeln!(code, "            ffn_hidden_buf,")?;
3887    writeln!(code, "            ffn_out_buf,")?;
3888    writeln!(code, "            add_tmp_buf,")?;
3889    writeln!(code, "            logits_buf,")?;
3890    writeln!(code, "            batch_hidden_buf,")?;
3891    writeln!(code, "            batch_residual_buf,")?;
3892    writeln!(code, "            batch_qkv_buf,")?;
3893    writeln!(code, "            batch_attn_out_buf,")?;
3894    writeln!(code, "            batch_attn_proj_buf,")?;
3895    writeln!(code, "            batch_gate_up_buf,")?;
3896    writeln!(code, "            batch_ffn_hidden_buf,")?;
3897    writeln!(code, "            batch_ffn_out_buf,")?;
3898    writeln!(code, "            batch_tokens_buf,")?;
3899    writeln!(code, "            batch_positions_buf,")?;
3900    writeln!(code, "            k_cache,")?;
3901    writeln!(code, "            v_cache,")?;
3902    writeln!(code, "            pos: 0,")?;
3903    writeln!(code, "            prev_cmd: None,")?;
3904    writeln!(code, "        }}")?;
3905    writeln!(code, "    }}")?;
3906    writeln!(code)?;
3907
3908    // ── forward() ──
3909    writeln!(
3910        code,
3911        "    /// Run the forward pass for a single token at the current position."
3912    )?;
3913    writeln!(code, "    ///")?;
3914    writeln!(
3915        code,
3916        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3917    )?;
3918    writeln!(code, "    ///")?;
3919    writeln!(
3920        code,
3921        "    /// All GPU operations are encoded into a single command buffer and"
3922    )?;
3923    writeln!(
3924        code,
3925        "    /// committed once at the end, avoiding per-operation synchronization."
3926    )?;
3927    writeln!(
3928        code,
3929        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3930    )?;
3931    writeln!(
3932        code,
3933        "        // Wait for any pending prefill command buffer"
3934    )?;
3935    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3936    writeln!(code, "            prev.wait_until_completed();")?;
3937    writeln!(code, "        }}")?;
3938    writeln!(code)?;
3939    writeln!(code, "        let pos = self.pos;")?;
3940    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3941    writeln!(code)?;
3942
3943    // Single compute encoder for the entire forward pass — no blit encoder
3944    // transitions. Copy operations use compute copy kernels instead of blits.
3945    let matmul_fn = if is_q8 {
3946        "dispatch_matmul_q8"
3947    } else if is_q4 {
3948        "dispatch_matmul_q4"
3949    } else {
3950        "dispatch_matmul"
3951    };
3952
3953    writeln!(
3954        code,
3955        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3956    )?;
3957    writeln!(code, "        {{")?;
3958    writeln!(
3959        code,
3960        "            let enc = cmd.new_compute_command_encoder();"
3961    )?;
3962    writeln!(code)?;
3963
3964    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3965    writeln!(
3966        code,
3967        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3968    )?;
3969    writeln!(
3970        code,
3971        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3972    )?;
3973    writeln!(
3974        code,
3975        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3976    )?;
3977    writeln!(
3978        code,
3979        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3980        hidden * 4,
3981    )?;
3982    writeln!(code, "            unsafe {{")?;
3983    writeln!(
3984        code,
3985        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
3986    )?;
3987    writeln!(
3988        code,
3989        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3990    )?;
3991    writeln!(
3992        code,
3993        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
3994    )?;
3995    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
3996    writeln!(
3997        code,
3998        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3999    )?;
4000    writeln!(code, "                    hidden_ptr,")?;
4001    writeln!(code, "                    HIDDEN_SIZE,")?;
4002    writeln!(code, "                );")?;
4003    writeln!(
4004        code,
4005        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4006    )?;
4007    writeln!(code, "            }}")?;
4008    writeln!(code)?;
4009
4010    // 2. Transformer layers
4011    writeln!(code, "            // 2. Transformer layers")?;
4012    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4013    writeln!(code)?;
4014    let q_byte_offset = 0usize;
4015    let k_byte_offset = hidden * 4;
4016    let v_byte_offset = (hidden + kv_dim) * 4;
4017
4018    writeln!(
4019        code,
4020        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
4021    )?;
4022    writeln!(
4023        code,
4024        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4025    )?;
4026    writeln!(
4027        code,
4028        "                // Fused Q+K+V matmul: single dispatch for all three projections"
4029    )?;
4030    writeln!(
4031        code,
4032        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4033    )?;
4034    if config.qkv_bias {
4035        writeln!(
4036            code,
4037            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4038        )?;
4039        writeln!(
4040            code,
4041            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4042        )?;
4043    }
4044    writeln!(
4045        code,
4046        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
4047    )?;
4048    writeln!(
4049        code,
4050        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4051    )?;
4052    writeln!(
4053        code,
4054        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4055    )?;
4056    writeln!(code)?;
4057    writeln!(
4058        code,
4059        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
4060    )?;
4061    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
4062    writeln!(
4063        code,
4064        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4065    )?;
4066    writeln!(
4067        code,
4068        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4069    )?;
4070    writeln!(code)?;
4071    writeln!(
4072        code,
4073        "                // Attention using Q from qkv_buf (offset 0)"
4074    )?;
4075    writeln!(
4076        code,
4077        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4078    )?;
4079    writeln!(
4080        code,
4081        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4082    )?;
4083    writeln!(
4084        code,
4085        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4086    )?;
4087    writeln!(
4088        code,
4089        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4090    )?;
4091    writeln!(
4092        code,
4093        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4094    )?;
4095    writeln!(
4096        code,
4097        "                // Fused gate+up matmul: single dispatch for both projections"
4098    )?;
4099    writeln!(
4100        code,
4101        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4102    )?;
4103    writeln!(
4104        code,
4105        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4106    )?;
4107    writeln!(
4108        code,
4109        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4110    )?;
4111    writeln!(
4112        code,
4113        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4114    )?;
4115    writeln!(code, "            }}")?;
4116    writeln!(code)?;
4117
4118    // 3. Final RMS norm + logits
4119    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
4120    writeln!(
4121        code,
4122        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4123    )?;
4124    writeln!(
4125        code,
4126        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4127    )?;
4128    writeln!(code)?;
4129    writeln!(code, "            enc.end_encoding();")?;
4130    writeln!(code, "        }}")?;
4131    writeln!(code)?;
4132
4133    // 5. Single commit + wait, then read back logits
4134    writeln!(
4135        code,
4136        "        // 5. Commit all GPU work and wait for completion"
4137    )?;
4138    writeln!(code, "        cmd.commit();")?;
4139    writeln!(code, "        cmd.wait_until_completed();")?;
4140    writeln!(code)?;
4141    writeln!(code, "        // 6. Read back logits from GPU")?;
4142    writeln!(code, "        let logits = unsafe {{")?;
4143    writeln!(
4144        code,
4145        "            let ptr = self.logits_buf.contents() as *const f32;"
4146    )?;
4147    writeln!(
4148        code,
4149        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4150    )?;
4151    writeln!(code, "        }};")?;
4152    writeln!(code)?;
4153    writeln!(code, "        self.pos += 1;")?;
4154    writeln!(code, "        logits")?;
4155    writeln!(code, "    }}")?;
4156    writeln!(code)?;
4157
4158    // ── forward_profile: instrumented forward with per-operation timing ──
4159    writeln!(
4160        code,
4161        "    /// Profiling forward pass that prints per-stage GPU timing."
4162    )?;
4163    writeln!(code, "    ///")?;
4164    writeln!(
4165        code,
4166        "    /// Each stage is committed and waited on separately so that GPU timestamps"
4167    )?;
4168    writeln!(
4169        code,
4170        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
4171    )?;
4172    writeln!(
4173        code,
4174        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
4175    )?;
4176    writeln!(
4177        code,
4178        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4179    )?;
4180    writeln!(code, "        use std::time::Instant;")?;
4181    writeln!(code)?;
4182    writeln!(
4183        code,
4184        "        // Wait for any pending prefill command buffer"
4185    )?;
4186    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4187    writeln!(code, "            prev.wait_until_completed();")?;
4188    writeln!(code, "        }}")?;
4189    writeln!(code)?;
4190    writeln!(code, "        let pos = self.pos;")?;
4191    writeln!(code)?;
4192
4193    // Stage: embedding (CPU, no GPU)
4194    writeln!(
4195        code,
4196        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
4197    )?;
4198    writeln!(code, "        let t_embed = Instant::now();")?;
4199    writeln!(code, "        unsafe {{")?;
4200    writeln!(
4201        code,
4202        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
4203    )?;
4204    writeln!(
4205        code,
4206        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4207    )?;
4208    writeln!(
4209        code,
4210        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
4211    )?;
4212    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
4213    writeln!(
4214        code,
4215        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4216    )?;
4217    writeln!(code, "                hidden_ptr,")?;
4218    writeln!(code, "                HIDDEN_SIZE,")?;
4219    writeln!(code, "            );")?;
4220    writeln!(
4221        code,
4222        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4223    )?;
4224    writeln!(code, "        }}")?;
4225    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
4226    writeln!(code)?;
4227
4228    // Stage: Transformer layers (all together on GPU)
4229    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
4230    writeln!(code, "        let t_layers = Instant::now();")?;
4231    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4232    writeln!(code, "        {{")?;
4233    writeln!(
4234        code,
4235        "            let enc = cmd.new_compute_command_encoder();"
4236    )?;
4237    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4238    writeln!(
4239        code,
4240        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4241    )?;
4242    writeln!(
4243        code,
4244        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4245    )?;
4246    if config.qkv_bias {
4247        writeln!(
4248            code,
4249            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4250        )?;
4251    }
4252    writeln!(
4253        code,
4254        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4255    )?;
4256    writeln!(
4257        code,
4258        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4259    )?;
4260    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
4261    writeln!(
4262        code,
4263        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4264    )?;
4265    writeln!(
4266        code,
4267        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4268    )?;
4269    writeln!(
4270        code,
4271        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4272    )?;
4273    writeln!(
4274        code,
4275        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4276    )?;
4277    writeln!(
4278        code,
4279        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4280    )?;
4281    writeln!(
4282        code,
4283        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4284    )?;
4285    writeln!(
4286        code,
4287        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4288    )?;
4289    writeln!(
4290        code,
4291        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4292    )?;
4293    writeln!(
4294        code,
4295        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4296    )?;
4297    writeln!(
4298        code,
4299        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4300    )?;
4301    writeln!(code, "            }}")?;
4302    writeln!(code, "            enc.end_encoding();")?;
4303    writeln!(code, "        }}")?;
4304    writeln!(code, "        cmd.commit();")?;
4305    writeln!(code, "        cmd.wait_until_completed();")?;
4306    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4307    writeln!(code)?;
4308
4309    // Stage: Final norm + logits
4310    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4311    writeln!(code, "        let t_logits = Instant::now();")?;
4312    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4313    writeln!(code, "        {{")?;
4314    writeln!(
4315        code,
4316        "            let enc = cmd2.new_compute_command_encoder();"
4317    )?;
4318    writeln!(
4319        code,
4320        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4321    )?;
4322    writeln!(
4323        code,
4324        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4325    )?;
4326    writeln!(code, "            enc.end_encoding();")?;
4327    writeln!(code, "        }}")?;
4328    writeln!(code, "        cmd2.commit();")?;
4329    writeln!(code, "        cmd2.wait_until_completed();")?;
4330    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4331    writeln!(code)?;
4332
4333    // Print profile results
4334    writeln!(
4335        code,
4336        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4337    )?;
4338    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4339    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4340    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4341    writeln!(
4342        code,
4343        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4344    )?;
4345    writeln!(code)?;
4346
4347    // Read back logits
4348    writeln!(code, "        let logits = unsafe {{")?;
4349    writeln!(
4350        code,
4351        "            let ptr = self.logits_buf.contents() as *const f32;"
4352    )?;
4353    writeln!(
4354        code,
4355        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4356    )?;
4357    writeln!(code, "        }};")?;
4358    writeln!(code)?;
4359    writeln!(code, "        self.pos += 1;")?;
4360    writeln!(code, "        logits")?;
4361    writeln!(code, "    }}")?;
4362    writeln!(code)?;
4363
4364    // ── forward_prefill: single-token async forward (backward compat) ──
4365    writeln!(
4366        code,
4367        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4368    )?;
4369    writeln!(code, "    ///")?;
4370    writeln!(
4371        code,
4372        "    /// Commits the command buffer without waiting, enabling double-buffered"
4373    )?;
4374    writeln!(
4375        code,
4376        "    /// execution: GPU processes token N while CPU encodes token N+1."
4377    )?;
4378    writeln!(
4379        code,
4380        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4381    )?;
4382    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4383    writeln!(code, "    }}")?;
4384    writeln!(code)?;
4385
4386    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4387    // Batched matmuls for QKV/O/FFN projections, sequential attention (causal dependency).
4388    let batch_matmul_fn = if is_q8 {
4389        "dispatch_matmul_q8_batch"
4390    } else if is_q4 {
4391        "dispatch_matmul_q4_batch"
4392    } else {
4393        "dispatch_matmul_batch"
4394    };
4395
4396    writeln!(
4397        code,
4398        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4399    )?;
4400    writeln!(code, "    ///")?;
4401    writeln!(
4402        code,
4403        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4404    )?;
4405    writeln!(
4406        code,
4407        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4408    )?;
4409    writeln!(
4410        code,
4411        "    /// This provides significant speedup during prompt prefill."
4412    )?;
4413    writeln!(
4414        code,
4415        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4416    )?;
4417    writeln!(code, "        if tokens.is_empty() {{ return; }}")?;
4418    writeln!(
4419        code,
4420        "        // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4421    )?;
4422    writeln!(
4423        code,
4424        "        // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4425    )?;
4426    writeln!(
4427        code,
4428        "        // longer than that must be processed iteratively.  The KV cache"
4429    )?;
4430    writeln!(code, "        // carries state across chunks via self.pos.")?;
4431    writeln!(
4432        code,
4433        "        for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4434    )?;
4435    writeln!(code, "        let m = chunk.len();")?;
4436    writeln!(code, "        if m == 0 {{ continue; }}")?;
4437    writeln!(code, "        let start_pos = self.pos;")?;
4438    writeln!(code)?;
4439    writeln!(code, "        // Wait for any pending command buffer")?;
4440    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4441    writeln!(code, "            prev.wait_until_completed();")?;
4442    writeln!(code, "        }}")?;
4443    writeln!(code)?;
4444
4445    // Upload token IDs and positions to GPU
4446    writeln!(
4447        code,
4448        "        // Upload token IDs and positions to GPU buffers"
4449    )?;
4450    writeln!(code, "        unsafe {{")?;
4451    writeln!(
4452        code,
4453        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4454    )?;
4455    writeln!(
4456        code,
4457        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4458    )?;
4459    writeln!(code, "            for i in 0..m {{")?;
4460    writeln!(code, "                *tok_ptr.add(i) = chunk[i];")?;
4461    writeln!(
4462        code,
4463        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4464    )?;
4465    writeln!(code, "            }}")?;
4466    writeln!(code, "        }}")?;
4467    writeln!(code)?;
4468
4469    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4470    writeln!(code, "        {{")?;
4471    writeln!(
4472        code,
4473        "            let enc = cmd.new_compute_command_encoder();"
4474    )?;
4475    writeln!(code)?;
4476
4477    // 1. Batch embedding lookup
4478    writeln!(
4479        code,
4480        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4481    )?;
4482    writeln!(
4483        code,
4484        "            self.dispatch_copy_embedding_batch(&enc, m);"
4485    )?;
4486    // Copy batch_hidden -> batch_residual
4487    writeln!(
4488        code,
4489        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4490    )?;
4491    writeln!(code)?;
4492
4493    // 2. Transformer layers
4494    writeln!(code, "            // 2. Transformer layers")?;
4495    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4496    writeln!(code)?;
4497
4498    // Batch RMS norm: residual -> hidden (batched)
4499    writeln!(
4500        code,
4501        "                // Batch RMS norm: batch_residual -> batch_hidden"
4502    )?;
4503    writeln!(
4504        code,
4505        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4506    )?;
4507
4508    // Batch QKV matmul
4509    writeln!(
4510        code,
4511        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4512    )?;
4513    writeln!(
4514        code,
4515        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4516    )?;
4517    if config.qkv_bias {
4518        writeln!(
4519            code,
4520            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4521        )?;
4522        writeln!(
4523            code,
4524            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4525        )?;
4526    }
4527    writeln!(code)?;
4528
4529    // Fused RoPE on Q+K portions in a single dispatch
4530    let k_float_offset = hidden;
4531    writeln!(
4532        code,
4533        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4534    )?;
4535    writeln!(
4536        code,
4537        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4538    )?;
4539    writeln!(code)?;
4540
4541    // Fused KV cache update: copy both K and V in a single dispatch
4542    let v_float_offset = hidden + kv_dim;
4543    writeln!(
4544        code,
4545        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4546    )?;
4547    writeln!(
4548        code,
4549        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4550    )?;
4551    writeln!(code)?;
4552
4553    // Batched causal attention: ONE dispatch for all M tokens
4554    writeln!(
4555        code,
4556        "                // Batched causal attention: one dispatch for all M tokens"
4557    )?;
4558    writeln!(
4559        code,
4560        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4561    )?;
4562    writeln!(code)?;
4563
4564    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4565    writeln!(code, "                // Batched O projection")?;
4566    writeln!(
4567        code,
4568        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4569    )?;
4570    writeln!(code)?;
4571
4572    // Batch add: residual += attn_proj for all tokens
4573    writeln!(
4574        code,
4575        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4576    )?;
4577    writeln!(code)?;
4578
4579    // Batch FFN
4580    writeln!(
4581        code,
4582        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4583    )?;
4584    writeln!(
4585        code,
4586        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4587    )?;
4588    writeln!(
4589        code,
4590        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4591    )?;
4592    writeln!(
4593        code,
4594        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4595    )?;
4596    writeln!(
4597        code,
4598        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4599    )?;
4600    writeln!(
4601        code,
4602        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4603    )?;
4604    writeln!(code, "            }}")?;
4605    writeln!(code)?;
4606
4607    // Copy last token's residual to single-token residual_buf for next forward() call
4608    writeln!(
4609        code,
4610        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4611    )?;
4612    writeln!(
4613        code,
4614        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4615    )?;
4616    writeln!(code)?;
4617    writeln!(code, "            enc.end_encoding();")?;
4618    writeln!(code, "        }}")?;
4619    writeln!(code)?;
4620
4621    writeln!(code, "        cmd.commit();")?;
4622    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4623    writeln!(code, "        self.pos += m;")?;
4624    writeln!(code, "        }}  // end for chunk")?;
4625    writeln!(code, "    }}")?;
4626    writeln!(code)?;
4627
4628    // ── reset() — rewind KV cache position for new inference requests ──
4629    writeln!(
4630        code,
4631        "    /// Reset the model state for a new inference request."
4632    )?;
4633    writeln!(code, "    pub fn reset(&mut self) {{")?;
4634    writeln!(code, "        self.pos = 0;")?;
4635    writeln!(code, "        self.prev_cmd = None;")?;
4636    writeln!(code, "    }}")?;
4637    writeln!(code)?;
4638
4639    // ── Private dispatch helpers (all take a shared compute encoder) ──
4640    writeln!(
4641        code,
4642        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4643    )?;
4644    writeln!(
4645        code,
4646        "    // These methods set pipeline state + buffers + dispatch on an existing"
4647    )?;
4648    writeln!(
4649        code,
4650        "    // encoder, avoiding per-operation encoder creation overhead."
4651    )?;
4652    writeln!(code)?;
4653
4654    // dispatch_rms_norm
4655    writeln!(
4656        code,
4657        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4658    )?;
4659    writeln!(
4660        code,
4661        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4662    )?;
4663    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4664    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4665    writeln!(
4666        code,
4667        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4668    )?;
4669    writeln!(
4670        code,
4671        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4672    )?;
4673    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4674    writeln!(
4675        code,
4676        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4677    )?;
4678    writeln!(
4679        code,
4680        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4681    )?;
4682    writeln!(
4683        code,
4684        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4685    )?;
4686    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4687    writeln!(
4688        code,
4689        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4690    )?;
4691    writeln!(
4692        code,
4693        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4694    )?;
4695    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4696    writeln!(code, "    }}")?;
4697    writeln!(code)?;
4698
4699    // dispatch_matmul
4700    writeln!(
4701        code,
4702        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4703    )?;
4704    writeln!(
4705        code,
4706        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4707    )?;
4708    writeln!(code, "        let r: u32 = rows as u32;")?;
4709    writeln!(code, "        let c: u32 = cols as u32;")?;
4710    writeln!(
4711        code,
4712        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4713    )?;
4714    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4715    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4716    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4717    writeln!(
4718        code,
4719        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4720    )?;
4721    writeln!(
4722        code,
4723        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4724    )?;
4725    writeln!(
4726        code,
4727        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4728    )?;
4729    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4730    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4731    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4732    writeln!(
4733        code,
4734        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4735    )?;
4736    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4737    writeln!(code, "    }}")?;
4738    writeln!(code)?;
4739
4740    // dispatch_matmul_q8
4741    writeln!(
4742        code,
4743        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4744    )?;
4745    writeln!(
4746        code,
4747        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4748    )?;
4749    writeln!(
4750        code,
4751        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4752    )?;
4753    writeln!(code, "        let r: u32 = rows as u32;")?;
4754    writeln!(code, "        let c: u32 = cols as u32;")?;
4755    writeln!(
4756        code,
4757        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4758    )?;
4759    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4760    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4761    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4762    writeln!(
4763        code,
4764        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4765    )?;
4766    writeln!(
4767        code,
4768        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4769    )?;
4770    writeln!(
4771        code,
4772        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4773    )?;
4774    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4775    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4776    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4777    writeln!(
4778        code,
4779        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4780    )?;
4781    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4782    writeln!(code, "    }}")?;
4783    writeln!(code)?;
4784
4785    // dispatch_matmul_q4
4786    writeln!(
4787        code,
4788        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4789    )?;
4790    writeln!(
4791        code,
4792        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4793    )?;
4794    writeln!(
4795        code,
4796        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4797    )?;
4798    writeln!(code, "        let r: u32 = rows as u32;")?;
4799    writeln!(code, "        let c: u32 = cols as u32;")?;
4800    writeln!(
4801        code,
4802        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4803    )?;
4804    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4805    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4806    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4807    writeln!(
4808        code,
4809        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4810    )?;
4811    writeln!(
4812        code,
4813        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4814    )?;
4815    writeln!(
4816        code,
4817        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4818    )?;
4819    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4820    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4821    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4822    writeln!(
4823        code,
4824        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4825    )?;
4826    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4827    writeln!(code, "    }}")?;
4828    writeln!(code)?;
4829
4830    // dispatch_rope
4831    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4832    writeln!(
4833        code,
4834        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4835    )?;
4836    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4837    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4838    writeln!(code, "        let p: u32 = pos as u32;")?;
4839    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4840    writeln!(
4841        code,
4842        "        let total_pairs = num_heads * (head_dim / 2);"
4843    )?;
4844    writeln!(
4845        code,
4846        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4847    )?;
4848    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4849    writeln!(
4850        code,
4851        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4852    )?;
4853    writeln!(
4854        code,
4855        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4856    )?;
4857    writeln!(
4858        code,
4859        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4860    )?;
4861    writeln!(
4862        code,
4863        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4864    )?;
4865    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4866    writeln!(
4867        code,
4868        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4869    )?;
4870    writeln!(
4871        code,
4872        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4873    )?;
4874    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4875    writeln!(code, "    }}")?;
4876    writeln!(code)?;
4877
4878    // dispatch_rope_offset
4879    writeln!(
4880        code,
4881        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4882    )?;
4883    writeln!(
4884        code,
4885        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4886    )?;
4887    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4888    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4889    writeln!(code, "        let p: u32 = pos as u32;")?;
4890    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4891    writeln!(
4892        code,
4893        "        let total_pairs = num_heads * (head_dim / 2);"
4894    )?;
4895    writeln!(
4896        code,
4897        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4898    )?;
4899    writeln!(
4900        code,
4901        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4902    )?;
4903    writeln!(
4904        code,
4905        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4906    )?;
4907    writeln!(
4908        code,
4909        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4910    )?;
4911    writeln!(
4912        code,
4913        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4914    )?;
4915    writeln!(
4916        code,
4917        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4918    )?;
4919    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4920    writeln!(
4921        code,
4922        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4923    )?;
4924    writeln!(
4925        code,
4926        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4927    )?;
4928    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4929    writeln!(code, "    }}")?;
4930    writeln!(code)?;
4931
4932    // dispatch_attention
4933    writeln!(
4934        code,
4935        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4936    )?;
4937    writeln!(
4938        code,
4939        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4940    )?;
4941    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4942    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4943    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4944    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4945    writeln!(
4946        code,
4947        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4948    )?;
4949    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4950    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4951    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4952    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4953    writeln!(
4954        code,
4955        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4956    )?;
4957    writeln!(
4958        code,
4959        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4960    )?;
4961    writeln!(
4962        code,
4963        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4964    )?;
4965    writeln!(
4966        code,
4967        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4968    )?;
4969    writeln!(
4970        code,
4971        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4972    )?;
4973    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4974    writeln!(
4975        code,
4976        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4977    )?;
4978    writeln!(
4979        code,
4980        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4981    )?;
4982    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4983    writeln!(code, "    }}")?;
4984    writeln!(code)?;
4985
4986    // dispatch_attention_offset
4987    writeln!(
4988        code,
4989        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4990    )?;
4991    writeln!(
4992        code,
4993        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4994    )?;
4995    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4996    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4997    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4998    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4999    writeln!(
5000        code,
5001        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
5002    )?;
5003    writeln!(
5004        code,
5005        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5006    )?;
5007    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5008    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5009    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5010    writeln!(
5011        code,
5012        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5013    )?;
5014    writeln!(
5015        code,
5016        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5017    )?;
5018    writeln!(
5019        code,
5020        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5021    )?;
5022    writeln!(
5023        code,
5024        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5025    )?;
5026    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5027    writeln!(
5028        code,
5029        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5030    )?;
5031    writeln!(
5032        code,
5033        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5034    )?;
5035    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5036    writeln!(code, "    }}")?;
5037    writeln!(code)?;
5038
5039    // dispatch_silu_mul
5040    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
5041    writeln!(
5042        code,
5043        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5044    )?;
5045    writeln!(code, "        let count: u32 = n as u32;")?;
5046    writeln!(
5047        code,
5048        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5049    )?;
5050    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
5051    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
5052    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5053    writeln!(
5054        code,
5055        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5056    )?;
5057    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5058    writeln!(
5059        code,
5060        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5061    )?;
5062    writeln!(
5063        code,
5064        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5065    )?;
5066    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5067    writeln!(code, "    }}")?;
5068    writeln!(code)?;
5069
5070    // dispatch_silu_mul_fused
5071    writeln!(
5072        code,
5073        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5074    )?;
5075    writeln!(
5076        code,
5077        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
5078    )?;
5079    writeln!(
5080        code,
5081        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5082    )?;
5083    writeln!(code, "        let count: u32 = n as u32;")?;
5084    writeln!(
5085        code,
5086        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5087    )?;
5088    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5089    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5090    writeln!(
5091        code,
5092        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5093    )?;
5094    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5095    writeln!(
5096        code,
5097        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5098    )?;
5099    writeln!(
5100        code,
5101        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5102    )?;
5103    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5104    writeln!(code, "    }}")?;
5105    writeln!(code)?;
5106
5107    // dispatch_copy (simple src -> dst copy via compute kernel)
5108    writeln!(
5109        code,
5110        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5111    )?;
5112    writeln!(
5113        code,
5114        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5115    )?;
5116    writeln!(code, "        let n: u32 = count as u32;")?;
5117    writeln!(
5118        code,
5119        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5120    )?;
5121    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5122    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5123    writeln!(
5124        code,
5125        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5126    )?;
5127    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5128    writeln!(
5129        code,
5130        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5131    )?;
5132    writeln!(
5133        code,
5134        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5135    )?;
5136    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5137    writeln!(code, "    }}")?;
5138    writeln!(code)?;
5139
5140    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
5141    writeln!(
5142        code,
5143        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5144    )?;
5145    writeln!(
5146        code,
5147        "    /// Used for embedding table lookup (copy a specific row)."
5148    )?;
5149    writeln!(
5150        code,
5151        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5152    )?;
5153    writeln!(code, "        let off: u32 = src_offset as u32;")?;
5154    writeln!(code, "        let n: u32 = count as u32;")?;
5155    writeln!(
5156        code,
5157        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5158    )?;
5159    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5160    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5161    writeln!(
5162        code,
5163        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5164    )?;
5165    writeln!(
5166        code,
5167        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5168    )?;
5169    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5170    writeln!(
5171        code,
5172        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5173    )?;
5174    writeln!(
5175        code,
5176        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5177    )?;
5178    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5179    writeln!(code, "    }}")?;
5180    writeln!(code)?;
5181
5182    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
5183    writeln!(
5184        code,
5185        "    /// Dispatch copy from source at byte offset to destination at float offset."
5186    )?;
5187    writeln!(
5188        code,
5189        "    /// Used for KV cache updates from fused QKV buffer."
5190    )?;
5191    writeln!(
5192        code,
5193        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5194    )?;
5195    writeln!(code, "        let n: u32 = count as u32;")?;
5196    writeln!(
5197        code,
5198        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5199    )?;
5200    writeln!(
5201        code,
5202        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5203    )?;
5204    writeln!(
5205        code,
5206        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5207    )?;
5208    writeln!(
5209        code,
5210        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5211    )?;
5212    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5213    writeln!(
5214        code,
5215        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5216    )?;
5217    writeln!(
5218        code,
5219        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5220    )?;
5221    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5222    writeln!(code, "    }}")?;
5223    writeln!(code)?;
5224
5225    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
5226    writeln!(
5227        code,
5228        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5229    )?;
5230    writeln!(
5231        code,
5232        "    /// Used for KV cache updates (write to a specific position in the cache)."
5233    )?;
5234    writeln!(
5235        code,
5236        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5237    )?;
5238    writeln!(code, "        let n: u32 = count as u32;")?;
5239    writeln!(
5240        code,
5241        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5242    )?;
5243    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5244    writeln!(
5245        code,
5246        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5247    )?;
5248    writeln!(
5249        code,
5250        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5251    )?;
5252    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5253    writeln!(
5254        code,
5255        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5256    )?;
5257    writeln!(
5258        code,
5259        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5260    )?;
5261    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5262    writeln!(code, "    }}")?;
5263    writeln!(code)?;
5264
5265    // dispatch_add_inplace (residual connection, no blit needed)
5266    writeln!(
5267        code,
5268        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5269    )?;
5270    writeln!(
5271        code,
5272        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5273    )?;
5274    writeln!(code, "        let count: u32 = n as u32;")?;
5275    writeln!(
5276        code,
5277        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5278    )?;
5279    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5280    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5281    writeln!(
5282        code,
5283        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5284    )?;
5285    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5286    writeln!(
5287        code,
5288        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5289    )?;
5290    writeln!(
5291        code,
5292        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5293    )?;
5294    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5295    writeln!(code, "    }}")?;
5296    writeln!(code)?;
5297
5298    // ── Batched prefill dispatch helpers ──
5299    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
5300    writeln!(code)?;
5301
5302    // dispatch_copy_embedding_batch
5303    writeln!(
5304        code,
5305        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
5306    )?;
5307    writeln!(
5308        code,
5309        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5310    )?;
5311    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
5312    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5313    writeln!(
5314        code,
5315        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5316    )?;
5317    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5318    writeln!(
5319        code,
5320        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5321    )?;
5322    writeln!(
5323        code,
5324        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5325    )?;
5326    writeln!(
5327        code,
5328        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5329    )?;
5330    writeln!(
5331        code,
5332        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5333    )?;
5334    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
5335    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5336    writeln!(
5337        code,
5338        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5339    )?;
5340    writeln!(
5341        code,
5342        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5343    )?;
5344    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5345    writeln!(code, "    }}")?;
5346    writeln!(code)?;
5347
5348    // dispatch_rms_norm_batch
5349    writeln!(
5350        code,
5351        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
5352    )?;
5353    writeln!(
5354        code,
5355        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5356    )?;
5357    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
5358    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
5359    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5360    writeln!(
5361        code,
5362        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5363    )?;
5364    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
5365    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
5366    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5367    writeln!(
5368        code,
5369        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5370    )?;
5371    writeln!(
5372        code,
5373        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5374    )?;
5375    writeln!(
5376        code,
5377        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5378    )?;
5379    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5380    writeln!(
5381        code,
5382        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5383    )?;
5384    writeln!(
5385        code,
5386        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5387    )?;
5388    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5389    writeln!(code, "    }}")?;
5390    writeln!(code)?;
5391
5392    // dispatch_matmul_batch (f32)
5393    writeln!(
5394        code,
5395        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5396    )?;
5397    writeln!(
5398        code,
5399        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5400    )?;
5401    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5402    writeln!(code, "        let r: u32 = rows as u32;")?;
5403    writeln!(code, "        let c: u32 = cols as u32;")?;
5404    writeln!(
5405        code,
5406        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5407    )?;
5408    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5409    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5410    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5411    writeln!(
5412        code,
5413        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5414    )?;
5415    writeln!(
5416        code,
5417        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5418    )?;
5419    writeln!(
5420        code,
5421        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5422    )?;
5423    writeln!(
5424        code,
5425        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
5426    )?;
5427    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5428    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5429    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5430    writeln!(
5431        code,
5432        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5433    )?;
5434    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5435    writeln!(code, "    }}")?;
5436    writeln!(code)?;
5437
5438    // dispatch_matmul_q8_batch
5439    writeln!(
5440        code,
5441        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5442    )?;
5443    writeln!(code, "    ///")?;
5444    writeln!(
5445        code,
5446        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5447    )?;
5448    writeln!(
5449        code,
5450        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5451    )?;
5452    writeln!(
5453        code,
5454        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5455    )?;
5456    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5457    writeln!(code, "        let r: u32 = rows as u32;")?;
5458    writeln!(code, "        let c: u32 = cols as u32;")?;
5459    writeln!(
5460        code,
5461        "        // Tile sizes must match the Metal shader constants."
5462    )?;
5463    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5464    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5465    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5466    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5467    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5468    writeln!(
5469        code,
5470        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5471    )?;
5472    writeln!(
5473        code,
5474        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5475    )?;
5476    writeln!(
5477        code,
5478        "        // dispatch count and reuses each weight load across 32 tokens."
5479    )?;
5480    writeln!(
5481        code,
5482        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5483    )?;
5484    writeln!(
5485        code,
5486        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5487    )?;
5488    writeln!(
5489        code,
5490        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5491    )?;
5492    writeln!(
5493        code,
5494        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5495    )?;
5496    writeln!(
5497        code,
5498        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5499    )?;
5500    writeln!(
5501        code,
5502        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5503    )?;
5504    writeln!(
5505        code,
5506        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5507    )?;
5508    writeln!(
5509        code,
5510        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5511    )?;
5512    writeln!(
5513        code,
5514        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5515    )?;
5516    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5517    writeln!(code, "            let pipe = if use_h4 {{")?;
5518    writeln!(code, "                if num_tokens >= 256 {{")?;
5519    writeln!(
5520        code,
5521        "                    &self.matmul_q8_mma32_hh4_pipeline"
5522    )?;
5523    writeln!(code, "                }} else {{")?;
5524    writeln!(
5525        code,
5526        "                    &self.matmul_q8_mma32_h4_pipeline"
5527    )?;
5528    writeln!(code, "                }}")?;
5529    writeln!(code, "            }} else {{")?;
5530    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5531    writeln!(code, "            }};")?;
5532    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
5533    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5534    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5535    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5536    writeln!(
5537        code,
5538        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5539    )?;
5540    writeln!(
5541        code,
5542        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5543    )?;
5544    writeln!(
5545        code,
5546        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5547    )?;
5548    writeln!(code, "            let row_tgs = rows / MMA32_ROW_TILE;")?;
5549    writeln!(
5550        code,
5551        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5552    )?;
5553    writeln!(
5554        code,
5555        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5556    )?;
5557    writeln!(
5558        code,
5559        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5560    )?;
5561    writeln!(
5562        code,
5563        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5564    )?;
5565    writeln!(
5566        code,
5567        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5568    )?;
5569    writeln!(
5570        code,
5571        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5572    )?;
5573    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5574    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5575    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5576    writeln!(
5577        code,
5578        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5579    )?;
5580    writeln!(
5581        code,
5582        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5583    )?;
5584    writeln!(
5585        code,
5586        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5587    )?;
5588    writeln!(code, "            let row_tgs = rows / MMA_ROW_TILE;")?;
5589    writeln!(
5590        code,
5591        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5592    )?;
5593    writeln!(
5594        code,
5595        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5596    )?;
5597    writeln!(
5598        code,
5599        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5600    )?;
5601    writeln!(
5602        code,
5603        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5604    )?;
5605    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5606    writeln!(
5607        code,
5608        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5609    )?;
5610    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5611    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5612    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5613    writeln!(
5614        code,
5615        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5616    )?;
5617    writeln!(
5618        code,
5619        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5620    )?;
5621    writeln!(
5622        code,
5623        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5624    )?;
5625    writeln!(
5626        code,
5627        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5628    )?;
5629    writeln!(
5630        code,
5631        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5632    )?;
5633    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5634    writeln!(
5635        code,
5636        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5637    )?;
5638    writeln!(
5639        code,
5640        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5641    )?;
5642    writeln!(code, "        }} else {{")?;
5643    writeln!(
5644        code,
5645        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5646    )?;
5647    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5648    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5649    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5650    writeln!(
5651        code,
5652        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5653    )?;
5654    writeln!(
5655        code,
5656        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5657    )?;
5658    writeln!(
5659        code,
5660        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5661    )?;
5662    writeln!(
5663        code,
5664        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5665    )?;
5666    writeln!(
5667        code,
5668        "            let num_tg = (row_tgs * num_tokens) as u64;"
5669    )?;
5670    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5671    writeln!(
5672        code,
5673        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5674    )?;
5675    writeln!(
5676        code,
5677        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5678    )?;
5679    writeln!(code, "        }}")?;
5680    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5681    writeln!(code, "    }}")?;
5682    writeln!(code)?;
5683
5684    // dispatch_matmul_q4_batch
5685    writeln!(
5686        code,
5687        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5688    )?;
5689    writeln!(
5690        code,
5691        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5692    )?;
5693    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5694    writeln!(code, "        let r: u32 = rows as u32;")?;
5695    writeln!(code, "        let c: u32 = cols as u32;")?;
5696    writeln!(
5697        code,
5698        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5699    )?;
5700    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5701    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5702    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5703    writeln!(
5704        code,
5705        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5706    )?;
5707    writeln!(
5708        code,
5709        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5710    )?;
5711    writeln!(
5712        code,
5713        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5714    )?;
5715    writeln!(
5716        code,
5717        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5718    )?;
5719    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5720    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5721    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5722    writeln!(
5723        code,
5724        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5725    )?;
5726    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5727    writeln!(code, "    }}")?;
5728    writeln!(code)?;
5729
5730    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5731    if config.qkv_bias {
5732        writeln!(
5733            code,
5734            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5735        )?;
5736        writeln!(
5737            code,
5738            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5739        )?;
5740        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5741        writeln!(code, "        let r: u32 = rows as u32;")?;
5742        writeln!(
5743            code,
5744            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5745        )?;
5746        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5747        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5748        writeln!(
5749            code,
5750            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5751        )?;
5752        writeln!(
5753            code,
5754            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5755        )?;
5756        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5757        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5758        writeln!(
5759            code,
5760            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5761        )?;
5762        writeln!(
5763            code,
5764            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5765        )?;
5766        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5767        writeln!(code, "    }}")?;
5768        writeln!(code)?;
5769    }
5770
5771    // dispatch_rope_batch
5772    writeln!(
5773        code,
5774        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5775    )?;
5776    writeln!(
5777        code,
5778        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5779    )?;
5780    writeln!(
5781        code,
5782        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5783    )?;
5784    writeln!(
5785        code,
5786        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5787    )?;
5788    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5789    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5790    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5791    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5792    writeln!(
5793        code,
5794        "        let pairs_per_token = num_heads * (head_dim / 2);"
5795    )?;
5796    writeln!(
5797        code,
5798        "        let total_pairs = num_tokens * pairs_per_token;"
5799    )?;
5800    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data.
5801    // Since our batch_qkv_buf is [M, qkv_rows] and Q/K are at offsets within each row,
5802    // we need to pass the buffer at the right byte offset for each token's data.
5803    // Actually, the rope_batch kernel accesses data[token * (num_heads * head_dim) + ...],
5804    // but our layout is data[token * row_stride + data_float_offset + ...].
5805    // We need the kernel to know the row_stride. Let me adjust the kernel approach:
5806    // Since Q and K are contiguous within each token's qkv_rows, and the batch buffer
5807    // is [M, qkv_rows], we can pass the buffer at offset (data_float_offset * 4) and
5808    // use a stride parameter. But the rope_batch kernel as written expects [M, num_heads*head_dim].
5809    //
5810    // Simplest approach: use the single-token rope kernel for each token in a loop.
5811    // This is still efficient because we're dispatching all within the same command encoder.
5812    writeln!(
5813        code,
5814        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5815    )?;
5816    writeln!(code, "        for t in 0..num_tokens {{")?;
5817    writeln!(
5818        code,
5819        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5820    )?;
5821    writeln!(
5822        code,
5823        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5824    )?;
5825    writeln!(
5826        code,
5827        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5828    )?;
5829    writeln!(
5830        code,
5831        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5832    )?;
5833    writeln!(
5834        code,
5835        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5836    )?;
5837    writeln!(
5838        code,
5839        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5840    )?;
5841    writeln!(
5842        code,
5843        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5844    )?;
5845    writeln!(
5846        code,
5847        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5848    )?;
5849    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5850    writeln!(
5851        code,
5852        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5853    )?;
5854    writeln!(
5855        code,
5856        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5857    )?;
5858    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5859    writeln!(code, "        }}")?;
5860    writeln!(code, "    }}")?;
5861    writeln!(code)?;
5862
5863    // dispatch_silu_mul_fused_batch
5864    writeln!(
5865        code,
5866        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5867    )?;
5868    writeln!(
5869        code,
5870        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5871    )?;
5872    writeln!(code, "        let count: u32 = n as u32;")?;
5873    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5874    writeln!(
5875        code,
5876        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5877    )?;
5878    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5879    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5880    writeln!(
5881        code,
5882        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5883    )?;
5884    writeln!(
5885        code,
5886        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5887    )?;
5888    writeln!(code, "        let total = num_tokens * n;")?;
5889    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5890    writeln!(
5891        code,
5892        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5893    )?;
5894    writeln!(
5895        code,
5896        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5897    )?;
5898    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5899    writeln!(code, "    }}")?;
5900    writeln!(code)?;
5901
5902    // dispatch_add_inplace_batch_n (add n elements in-place)
5903    writeln!(
5904        code,
5905        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5906    )?;
5907    writeln!(
5908        code,
5909        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5910    )?;
5911    writeln!(code, "        let count: u32 = total_n as u32;")?;
5912    writeln!(
5913        code,
5914        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5915    )?;
5916    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5917    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5918    writeln!(
5919        code,
5920        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5921    )?;
5922    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5923    writeln!(
5924        code,
5925        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5926    )?;
5927    writeln!(
5928        code,
5929        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5930    )?;
5931    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5932    writeln!(code, "    }}")?;
5933    writeln!(code)?;
5934
5935    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5936    writeln!(
5937        code,
5938        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5939    )?;
5940    writeln!(
5941        code,
5942        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5943    )?;
5944    writeln!(code, "        let n: u32 = count as u32;")?;
5945    writeln!(
5946        code,
5947        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5948    )?;
5949    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5950    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5951    writeln!(
5952        code,
5953        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5954    )?;
5955    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5956    writeln!(
5957        code,
5958        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5959    )?;
5960    writeln!(
5961        code,
5962        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5963    )?;
5964    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5965    writeln!(code, "    }}")?;
5966    writeln!(code)?;
5967
5968    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
5969    writeln!(
5970        code,
5971        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5972    )?;
5973    writeln!(
5974        code,
5975        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5976    )?;
5977    writeln!(code, "        let n: u32 = count as u32;")?;
5978    writeln!(
5979        code,
5980        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5981    )?;
5982    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5983    writeln!(
5984        code,
5985        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5986    )?;
5987    writeln!(
5988        code,
5989        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5990    )?;
5991    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5992    writeln!(
5993        code,
5994        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5995    )?;
5996    writeln!(
5997        code,
5998        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5999    )?;
6000    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6001    writeln!(code, "    }}")?;
6002    writeln!(code)?;
6003
6004    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
6005    writeln!(
6006        code,
6007        "    /// Copy from src at byte offset to dst at float offset."
6008    )?;
6009    writeln!(
6010        code,
6011        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6012    )?;
6013    writeln!(code, "        let n: u32 = count as u32;")?;
6014    writeln!(
6015        code,
6016        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6017    )?;
6018    writeln!(
6019        code,
6020        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6021    )?;
6022    writeln!(
6023        code,
6024        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6025    )?;
6026    writeln!(
6027        code,
6028        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6029    )?;
6030    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6031    writeln!(
6032        code,
6033        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6034    )?;
6035    writeln!(
6036        code,
6037        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6038    )?;
6039    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6040    writeln!(code, "    }}")?;
6041    writeln!(code)?;
6042
6043    // dispatch_copy_kv_batch
6044    writeln!(
6045        code,
6046        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6047    )?;
6048    writeln!(
6049        code,
6050        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6051    )?;
6052    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6053    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6054    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6055    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6056    writeln!(code, "        let so: u32 = src_offset as u32;")?;
6057    writeln!(
6058        code,
6059        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6060    )?;
6061    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6062    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6063    writeln!(
6064        code,
6065        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6066    )?;
6067    writeln!(
6068        code,
6069        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6070    )?;
6071    writeln!(
6072        code,
6073        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6074    )?;
6075    writeln!(
6076        code,
6077        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6078    )?;
6079    writeln!(
6080        code,
6081        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6082    )?;
6083    writeln!(code, "        let total = num_tokens * kv_dim;")?;
6084    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6085    writeln!(
6086        code,
6087        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6088    )?;
6089    writeln!(
6090        code,
6091        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6092    )?;
6093    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6094    writeln!(code, "    }}")?;
6095    writeln!(code)?;
6096
6097    // dispatch_attention_batch
6098    writeln!(
6099        code,
6100        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
6101    )?;
6102    writeln!(
6103        code,
6104        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6105    )?;
6106    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6107    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6108    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
6109    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6110    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6111    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
6112    // Attention kernel selection:
6113    //   * Legacy `attention_batch` materializes scores[4096] in threadgroup memory
6114    //     and uses scalar simdgroup reductions.  Fast at short seq_len, no MMA.
6115    //   * `attention_flash_batch` streams K/V with online softmax; no seq cap,
6116    //     scalar math, ~7-14 % slower than legacy at long contexts (no MMA).
6117    //   * `attention_mma_flash_batch` adds hardware simdgroup_matrix<half, 8, 8>
6118    //     MMA for both Q·K^T and P·V, processing Q_BLOCK=8 tokens per TG.
6119    //     Default path when HEAD_DIM ≤ 128 and num_tokens ≥ 8 (verified on
6120    //     Llama, Qwen2.5, Phi-3).  Set FORGE_MMA_ATTN=0 to force legacy.
6121    writeln!(code, "        let max_seq = base_pos + num_tokens;")?;
6122    writeln!(code, "        let _ = max_seq;")?;
6123    writeln!(
6124        code,
6125        "        let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6126    )?;
6127    writeln!(code, "            .map(|v| v == \"0\").unwrap_or(false);")?;
6128    writeln!(
6129        code,
6130        "        let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6131    )?;
6132    writeln!(code, "        if use_mma_flash {{")?;
6133    writeln!(
6134        code,
6135        "            let pipe = &self.attention_mma_flash_batch_pipeline;"
6136    )?;
6137    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
6138    writeln!(code, "            enc.set_buffer(0, Some(q_buf), 0);")?;
6139    writeln!(code, "            enc.set_buffer(1, Some(k_cache), 0);")?;
6140    writeln!(code, "            enc.set_buffer(2, Some(v_cache), 0);")?;
6141    writeln!(code, "            enc.set_buffer(3, Some(output), 0);")?;
6142    writeln!(
6143        code,
6144        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6145    )?;
6146    writeln!(
6147        code,
6148        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6149    )?;
6150    writeln!(
6151        code,
6152        "            enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6153    )?;
6154    writeln!(
6155        code,
6156        "            enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6157    )?;
6158    writeln!(
6159        code,
6160        "            enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6161    )?;
6162    writeln!(
6163        code,
6164        "            enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6165    )?;
6166    writeln!(
6167        code,
6168        "            // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6169    )?;
6170    writeln!(code, "            let tg_size = MTLSize::new(128, 1, 1);")?;
6171    writeln!(
6172        code,
6173        "            let q_blocks = ((num_tokens + 7) / 8) as u64;"
6174    )?;
6175    writeln!(
6176        code,
6177        "            let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6178    )?;
6179    writeln!(
6180        code,
6181        "            enc.dispatch_thread_groups(grid_size, tg_size);"
6182    )?;
6183    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6184    writeln!(code, "            return;")?;
6185    writeln!(code, "        }}")?;
6186    writeln!(code, "        let pipe = &self.attention_batch_pipeline;")?;
6187    writeln!(code, "        enc.set_compute_pipeline_state(pipe);")?;
6188    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
6189    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
6190    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
6191    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
6192    writeln!(
6193        code,
6194        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6195    )?;
6196    writeln!(
6197        code,
6198        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6199    )?;
6200    writeln!(
6201        code,
6202        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6203    )?;
6204    writeln!(
6205        code,
6206        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6207    )?;
6208    writeln!(
6209        code,
6210        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6211    )?;
6212    writeln!(
6213        code,
6214        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6215    )?;
6216    writeln!(
6217        code,
6218        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6219    )?;
6220    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6221    writeln!(
6222        code,
6223        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6224    )?;
6225    writeln!(
6226        code,
6227        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6228    )?;
6229    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6230    writeln!(code, "    }}")?;
6231    writeln!(code)?;
6232
6233    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
6234    writeln!(
6235        code,
6236        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6237    )?;
6238    writeln!(
6239        code,
6240        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6241    )?;
6242    writeln!(
6243        code,
6244        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6245    )?;
6246    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6247    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6248    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
6249    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6250    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6251    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
6252    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
6253    writeln!(
6254        code,
6255        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6256    )?;
6257    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
6258    writeln!(
6259        code,
6260        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6261    )?;
6262    writeln!(
6263        code,
6264        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6265    )?;
6266    writeln!(
6267        code,
6268        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6269    )?;
6270    writeln!(
6271        code,
6272        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6273    )?;
6274    writeln!(
6275        code,
6276        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6277    )?;
6278    writeln!(
6279        code,
6280        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6281    )?;
6282    writeln!(
6283        code,
6284        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6285    )?;
6286    writeln!(
6287        code,
6288        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6289    )?;
6290    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6291    writeln!(
6292        code,
6293        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6294    )?;
6295    writeln!(
6296        code,
6297        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6298    )?;
6299    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6300    writeln!(code, "    }}")?;
6301    writeln!(code)?;
6302
6303    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
6304    writeln!(
6305        code,
6306        "    /// Dispatch fused K+V cache copy in one kernel launch."
6307    )?;
6308    writeln!(
6309        code,
6310        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6311    )?;
6312    writeln!(
6313        code,
6314        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6315    )?;
6316    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6317    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6318    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6319    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6320    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6321    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6322    writeln!(
6323        code,
6324        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6325    )?;
6326    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6327    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6328    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6329    writeln!(
6330        code,
6331        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6332    )?;
6333    writeln!(
6334        code,
6335        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6336    )?;
6337    writeln!(
6338        code,
6339        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6340    )?;
6341    writeln!(
6342        code,
6343        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6344    )?;
6345    writeln!(
6346        code,
6347        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6348    )?;
6349    writeln!(
6350        code,
6351        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6352    )?;
6353    writeln!(
6354        code,
6355        "        let total = num_tokens * kv_dim * 2;  // K + V"
6356    )?;
6357    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6358    writeln!(
6359        code,
6360        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6361    )?;
6362    writeln!(
6363        code,
6364        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6365    )?;
6366    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6367    writeln!(code, "    }}")?;
6368
6369    writeln!(code, "}}")?;
6370    writeln!(code)?;
6371
6372    Ok(())
6373}
6374
6375fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6376    writeln!(
6377        code,
6378        "// ── Helper functions ──────────────────────────────────"
6379    )?;
6380    writeln!(code)?;
6381    writeln!(
6382        code,
6383        "/// Create a compute pipeline from a named function in the Metal library."
6384    )?;
6385    writeln!(
6386        code,
6387        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6388    )?;
6389    writeln!(
6390        code,
6391        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6392    )?;
6393    writeln!(
6394        code,
6395        "    device.new_compute_pipeline_state_with_function(&func)"
6396    )?;
6397    writeln!(
6398        code,
6399        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6400    )?;
6401    writeln!(code, "}}")?;
6402    writeln!(code)?;
6403
6404    Ok(())
6405}
6406
6407// ---------------------------------------------------------------------------
6408// main.rs generation
6409// ---------------------------------------------------------------------------
6410
6411fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6412    let _sanitized = sanitize_name(model_name);
6413    let _vocab = config.vocab_size;
6414
6415    let mut code = String::with_capacity(16 * 1024);
6416    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6417    writeln!(
6418        code,
6419        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6420    )?;
6421    writeln!(code)?;
6422    writeln!(code, "mod model;")?;
6423    writeln!(code)?;
6424    writeln!(code, "use std::io::Write;")?;
6425    writeln!(code, "use std::time::Instant;")?;
6426    writeln!(code, "use serde::Deserialize;")?;
6427    writeln!(code)?;
6428
6429    // -- main function --
6430    writeln!(code, "fn main() {{")?;
6431    writeln!(
6432        code,
6433        "    let args: Vec<String> = std::env::args().collect();"
6434    )?;
6435    writeln!(code)?;
6436    writeln!(
6437        code,
6438        "    // Detect --serve mode (only requires weights + tokenizer)"
6439    )?;
6440    writeln!(
6441        code,
6442        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6443    )?;
6444    writeln!(code)?;
6445    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6446    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6447    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6448    writeln!(code, "        std::process::exit(1);")?;
6449    writeln!(code, "    }}")?;
6450    writeln!(code)?;
6451    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6452    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6453    writeln!(code, "        std::process::exit(1);")?;
6454    writeln!(code, "    }}")?;
6455    writeln!(code)?;
6456    writeln!(code, "    let weights_path = &args[1];")?;
6457    writeln!(code, "    let tokenizer_path = &args[2];")?;
6458    writeln!(code)?;
6459    writeln!(code, "    // Parse optional flags")?;
6460    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6461    writeln!(code, "    let mut port: u16 = 8080;")?;
6462    writeln!(
6463        code,
6464        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6465    )?;
6466    writeln!(
6467        code,
6468        "    let profile = args.iter().any(|a| a == \"--profile\");"
6469    )?;
6470    writeln!(code, "    let mut i = 3;")?;
6471    writeln!(code, "    while i < args.len() {{")?;
6472    writeln!(
6473        code,
6474        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6475    )?;
6476    writeln!(
6477        code,
6478        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6479    )?;
6480    writeln!(code, "            i += 2;")?;
6481    writeln!(
6482        code,
6483        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6484    )?;
6485    writeln!(
6486        code,
6487        "            port = args[i + 1].parse().unwrap_or(8080);"
6488    )?;
6489    writeln!(code, "            i += 2;")?;
6490    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6491    writeln!(code, "            i += 1;")?;
6492    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6493    writeln!(code, "            i += 1;")?;
6494    writeln!(code, "        }} else {{")?;
6495    writeln!(code, "            i += 1;")?;
6496    writeln!(code, "        }}")?;
6497    writeln!(code, "    }}")?;
6498    writeln!(code)?;
6499
6500    // -- load model (shared by both modes) --
6501    writeln!(
6502        code,
6503        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6504    )?;
6505    writeln!(
6506        code,
6507        "    let weights_file = std::fs::File::open(weights_path)"
6508    )?;
6509    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6510    writeln!(
6511        code,
6512        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6513    )?;
6514    writeln!(code)?;
6515    writeln!(code, "    // Load tokenizer")?;
6516    writeln!(
6517        code,
6518        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6519    )?;
6520    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6521    writeln!(code)?;
6522    writeln!(code, "    // Create Metal model")?;
6523    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6524    writeln!(
6525        code,
6526        "    let mut model = model::MetalModel::new(&weights_mmap);"
6527    )?;
6528    writeln!(code)?;
6529
6530    // -- branch: serve vs CLI --
6531    writeln!(code, "    if serve_mode {{")?;
6532    writeln!(code, "        serve(model, tokenizer, port);")?;
6533    writeln!(code, "    }} else {{")?;
6534    writeln!(code, "        let prompt = &args[3];")?;
6535    writeln!(
6536        code,
6537        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6538    )?;
6539    writeln!(code, "    }}")?;
6540    writeln!(code, "}}")?;
6541    writeln!(code)?;
6542
6543    // -- cli_mode function --
6544    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6545    writeln!(code, "    // Tokenize prompt")?;
6546    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6547    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6548    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6549    writeln!(code)?;
6550    writeln!(
6551        code,
6552        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6553    )?;
6554    writeln!(
6555        code,
6556        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6557    )?;
6558    writeln!(
6559        code,
6560        "    // The last token uses synchronous forward() to get logits."
6561    )?;
6562    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6563    writeln!(code, "    let prefill_start = Instant::now();")?;
6564    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6565    writeln!(
6566        code,
6567        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6568    )?;
6569    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6570    writeln!(code, "    }} else {{")?;
6571    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6572    writeln!(code, "    }};")?;
6573    writeln!(
6574        code,
6575        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6576    )?;
6577    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6578    writeln!(
6579        code,
6580        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6581    )?;
6582    writeln!(
6583        code,
6584        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6585    )?;
6586    writeln!(code)?;
6587    writeln!(code, "    // Generate tokens")?;
6588    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6589    writeln!(code, "    let gen_start = Instant::now();")?;
6590    writeln!(code, "    let mut generated_count: usize = 0;")?;
6591    writeln!(code)?;
6592    writeln!(code, "    for _ in 0..max_tokens {{")?;
6593    writeln!(
6594        code,
6595        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6596    )?;
6597    writeln!(code, "            if !quiet {{")?;
6598    writeln!(code, "                print!(\"{{}}\", text);")?;
6599    writeln!(code, "                std::io::stdout().flush().ok();")?;
6600    writeln!(code, "            }}")?;
6601    writeln!(code, "        }}")?;
6602    writeln!(code, "        generated_count += 1;")?;
6603    writeln!(code)?;
6604    writeln!(
6605        code,
6606        "        // Use profiling forward for first token when --profile is set"
6607    )?;
6608    writeln!(
6609        code,
6610        "        let logits = if profile && generated_count == 1 {{"
6611    )?;
6612    writeln!(code, "            model.forward_profile(next_token)")?;
6613    writeln!(code, "        }} else {{")?;
6614    writeln!(code, "            model.forward(next_token)")?;
6615    writeln!(code, "        }};")?;
6616    writeln!(code, "        next_token = argmax(&logits);")?;
6617    writeln!(code)?;
6618    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6619    writeln!(code, "        if next_token == 2 {{")?;
6620    writeln!(code, "            break;")?;
6621    writeln!(code, "        }}")?;
6622    writeln!(code)?;
6623    writeln!(
6624        code,
6625        "        // Yield between tokens to reduce sustained GPU thermal load."
6626    )?;
6627    writeln!(
6628        code,
6629        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6630    )?;
6631    writeln!(
6632        code,
6633        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6634    )?;
6635    writeln!(
6636        code,
6637        "        // briefly, providing a micro-break that helps sustain peak throughput."
6638    )?;
6639    writeln!(code, "        std::thread::yield_now();")?;
6640    writeln!(code, "    }}")?;
6641    writeln!(code, "    if !quiet {{")?;
6642    writeln!(code, "        println!();")?;
6643    writeln!(code, "    }}")?;
6644    writeln!(
6645        code,
6646        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6647    )?;
6648    writeln!(
6649        code,
6650        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6651    )?;
6652    writeln!(
6653        code,
6654        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6655    )?;
6656    writeln!(code, "}}")?;
6657    writeln!(code)?;
6658
6659    // -- argmax helper --
6660    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6661    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6662    writeln!(code, "    logits.iter()")?;
6663    writeln!(code, "        .enumerate()")?;
6664    writeln!(
6665        code,
6666        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6667    )?;
6668    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6669    writeln!(code, "        .unwrap_or(0)")?;
6670    writeln!(code, "}}")?;
6671    writeln!(code)?;
6672
6673    // -- Request/Response types for OpenAI API --
6674    writeln!(
6675        code,
6676        "// -----------------------------------------------------------------------"
6677    )?;
6678    writeln!(code, "// OpenAI-compatible API server")?;
6679    writeln!(
6680        code,
6681        "// -----------------------------------------------------------------------"
6682    )?;
6683    writeln!(code)?;
6684    writeln!(code, "#[derive(Deserialize)]")?;
6685    writeln!(code, "struct ChatRequest {{")?;
6686    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6687    writeln!(code, "    #[serde(default)]")?;
6688    writeln!(code, "    stream: Option<bool>,")?;
6689    writeln!(code, "    #[serde(default)]")?;
6690    writeln!(code, "    max_tokens: Option<usize>,")?;
6691    writeln!(code, "    #[serde(default)]")?;
6692    writeln!(code, "    temperature: Option<f32>,")?;
6693    writeln!(code, "    #[serde(default)]")?;
6694    writeln!(code, "    model: Option<String>,")?;
6695    writeln!(code, "}}")?;
6696    writeln!(code)?;
6697    writeln!(code, "#[derive(Deserialize)]")?;
6698    writeln!(code, "struct ChatMessage {{")?;
6699    writeln!(code, "    role: String,")?;
6700    writeln!(code, "    content: String,")?;
6701    writeln!(code, "}}")?;
6702    writeln!(code)?;
6703
6704    // -- format_chat_messages --
6705    writeln!(
6706        code,
6707        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6708    )?;
6709    writeln!(code, "    let mut prompt = String::new();")?;
6710    writeln!(code, "    for msg in messages {{")?;
6711    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6712    writeln!(code, "    }}")?;
6713    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6714    writeln!(code, "    prompt")?;
6715    writeln!(code, "}}")?;
6716    writeln!(code)?;
6717
6718    // -- prefill helper --
6719    writeln!(
6720        code,
6721        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6722    )?;
6723    writeln!(code, "    let len = tokens.len();")?;
6724    writeln!(code, "    if len > 1 {{")?;
6725    writeln!(
6726        code,
6727        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6728    )?;
6729    writeln!(code, "    }}")?;
6730    writeln!(code, "    model.forward(tokens[len - 1])")?;
6731    writeln!(code, "}}")?;
6732    writeln!(code)?;
6733
6734    // -- serve function --
6735    writeln!(
6736        code,
6737        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6738    )?;
6739    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6740    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6741    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6742    writeln!(
6743        code,
6744        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6745    )?;
6746    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6747    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6748    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6749    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6750    writeln!(code)?;
6751    writeln!(code, "    for request in server.incoming_requests() {{")?;
6752    writeln!(code, "        let method = request.method().to_string();")?;
6753    writeln!(code, "        let url = request.url().to_string();")?;
6754    writeln!(code)?;
6755    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6756
6757    // -- POST /v1/chat/completions --
6758    writeln!(
6759        code,
6760        "            (\"POST\", \"/v1/chat/completions\") => {{"
6761    )?;
6762    writeln!(
6763        code,
6764        "                handle_chat_completion(&mut model, &tokenizer, request);"
6765    )?;
6766    writeln!(code, "            }}")?;
6767
6768    // -- GET /v1/models --
6769    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6770    writeln!(code, "                let body = serde_json::json!({{")?;
6771    writeln!(code, "                    \"object\": \"list\",")?;
6772    writeln!(code, "                    \"data\": [{{")?;
6773    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6774    writeln!(code, "                        \"object\": \"model\",")?;
6775    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6776    writeln!(code, "                    }}]")?;
6777    writeln!(code, "                }});")?;
6778    writeln!(
6779        code,
6780        "                let resp = tiny_http::Response::from_string(body.to_string())"
6781    )?;
6782    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6783    writeln!(code, "                request.respond(resp).ok();")?;
6784    writeln!(code, "            }}")?;
6785
6786    // -- GET /health --
6787    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6788    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6789    writeln!(code, "                request.respond(resp).ok();")?;
6790    writeln!(code, "            }}")?;
6791
6792    // -- 404 --
6793    writeln!(code, "            _ => {{")?;
6794    writeln!(
6795        code,
6796        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6797    )?;
6798    writeln!(code, "                    .with_status_code(404);")?;
6799    writeln!(code, "                request.respond(resp).ok();")?;
6800    writeln!(code, "            }}")?;
6801    writeln!(code, "        }}")?;
6802    writeln!(code, "    }}")?;
6803    writeln!(code, "}}")?;
6804    writeln!(code)?;
6805
6806    // -- handle_chat_completion --
6807    writeln!(code, "fn handle_chat_completion(")?;
6808    writeln!(code, "    model: &mut model::MetalModel,")?;
6809    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6810    writeln!(code, "    mut request: tiny_http::Request,")?;
6811    writeln!(code, ") {{")?;
6812    writeln!(code, "    // Read request body")?;
6813    writeln!(code, "    let mut body = String::new();")?;
6814    writeln!(
6815        code,
6816        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6817    )?;
6818    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6819    writeln!(code, "            .with_status_code(400);")?;
6820    writeln!(code, "        request.respond(resp).ok();")?;
6821    writeln!(code, "        return;")?;
6822    writeln!(code, "    }}")?;
6823    writeln!(code)?;
6824    writeln!(code, "    // Parse JSON")?;
6825    writeln!(
6826        code,
6827        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6828    )?;
6829    writeln!(code, "        Ok(r) => r,")?;
6830    writeln!(code, "        Err(e) => {{")?;
6831    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6832    writeln!(code, "                .with_status_code(400);")?;
6833    writeln!(code, "            request.respond(resp).ok();")?;
6834    writeln!(code, "            return;")?;
6835    writeln!(code, "        }}")?;
6836    writeln!(code, "    }};")?;
6837    writeln!(code)?;
6838    writeln!(
6839        code,
6840        "    let prompt = format_chat_messages(&req.messages);"
6841    )?;
6842    writeln!(
6843        code,
6844        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6845    )?;
6846    writeln!(code, "        Ok(e) => e,")?;
6847    writeln!(code, "        Err(e) => {{")?;
6848    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6849    writeln!(code, "                .with_status_code(500);")?;
6850    writeln!(code, "            request.respond(resp).ok();")?;
6851    writeln!(code, "            return;")?;
6852    writeln!(code, "        }}")?;
6853    writeln!(code, "    }};")?;
6854    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6855    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6856    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6857    writeln!(
6858        code,
6859        "    let _temperature = req.temperature.unwrap_or(1.0);"
6860    )?;
6861    writeln!(code)?;
6862
6863    // -- Reset KV cache for each request --
6864    writeln!(code, "    model.reset();")?;
6865    writeln!(code)?;
6866
6867    // -- Prefill with timing --
6868    writeln!(code, "    let prefill_start = Instant::now();")?;
6869    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6870    writeln!(
6871        code,
6872        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6873    )?;
6874    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6875    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6876    writeln!(code)?;
6877
6878    writeln!(code, "    if stream {{")?;
6879
6880    // -- SSE streaming response --
6881    writeln!(
6882        code,
6883        "        // SSE streaming: generate tokens and build SSE body"
6884    )?;
6885    writeln!(code, "        let gen_start = Instant::now();")?;
6886    writeln!(code, "        let mut generated_count: usize = 0;")?;
6887    writeln!(code, "        let mut sse_body = String::new();")?;
6888    writeln!(code, "        for _ in 0..max_tokens {{")?;
6889    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6890    writeln!(
6891        code,
6892        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6893    )?;
6894    writeln!(
6895        code,
6896        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6897    )?;
6898    writeln!(
6899        code,
6900        "                // escaped includes surrounding quotes, strip them"
6901    )?;
6902    writeln!(
6903        code,
6904        "                let inner = &escaped[1..escaped.len()-1];"
6905    )?;
6906    writeln!(code, "                sse_body.push_str(&format!(")?;
6907    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6908    writeln!(code, "                    inner")?;
6909    writeln!(code, "                ));")?;
6910    writeln!(code, "            }}")?;
6911    writeln!(code, "            generated_count += 1;")?;
6912    writeln!(code, "            let logits = model.forward(next_token);")?;
6913    writeln!(code, "            next_token = argmax(&logits);")?;
6914    writeln!(code, "        }}")?;
6915    writeln!(
6916        code,
6917        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6918    )?;
6919    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6920    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6921    writeln!(code)?;
6922    writeln!(
6923        code,
6924        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6925    )?;
6926    writeln!(code, "        sse_body.push_str(&format!(")?;
6927    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6928    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6929    writeln!(code, "        ));")?;
6930    writeln!(code)?;
6931    writeln!(
6932        code,
6933        "        let resp = tiny_http::Response::from_string(sse_body)"
6934    )?;
6935    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6936    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6937    writeln!(code, "        request.respond(resp).ok();")?;
6938
6939    writeln!(code, "    }} else {{")?;
6940
6941    // -- Non-streaming response --
6942    writeln!(
6943        code,
6944        "        // Non-streaming: generate all tokens, return JSON"
6945    )?;
6946    writeln!(code, "        let gen_start = Instant::now();")?;
6947    writeln!(code, "        let mut generated_count: usize = 0;")?;
6948    writeln!(code, "        let mut generated = String::new();")?;
6949    writeln!(code, "        for _ in 0..max_tokens {{")?;
6950    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6951    writeln!(
6952        code,
6953        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6954    )?;
6955    writeln!(code, "                generated.push_str(&text);")?;
6956    writeln!(code, "            }}")?;
6957    writeln!(code, "            generated_count += 1;")?;
6958    writeln!(code, "            let logits = model.forward(next_token);")?;
6959    writeln!(code, "            next_token = argmax(&logits);")?;
6960    writeln!(code, "        }}")?;
6961    writeln!(
6962        code,
6963        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6964    )?;
6965    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6966    writeln!(code)?;
6967    writeln!(code, "        let resp_json = serde_json::json!({{")?;
6968    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
6969    writeln!(code, "            \"object\": \"chat.completion\",")?;
6970    writeln!(code, "            \"choices\": [{{")?;
6971    writeln!(code, "                \"index\": 0,")?;
6972    writeln!(code, "                \"message\": {{")?;
6973    writeln!(code, "                    \"role\": \"assistant\",")?;
6974    writeln!(code, "                    \"content\": generated")?;
6975    writeln!(code, "                }},")?;
6976    writeln!(code, "                \"finish_reason\": \"stop\"")?;
6977    writeln!(code, "            }}],")?;
6978    writeln!(code, "            \"usage\": {{")?;
6979    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
6980    writeln!(
6981        code,
6982        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6983    )?;
6984    writeln!(
6985        code,
6986        "                \"generation_tokens\": generated_count,"
6987    )?;
6988    writeln!(
6989        code,
6990        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6991    )?;
6992    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
6993    writeln!(code, "            }}")?;
6994    writeln!(code, "        }});")?;
6995    writeln!(
6996        code,
6997        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
6998    )?;
6999    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7000    writeln!(code, "        request.respond(resp).ok();")?;
7001    writeln!(code, "    }}")?;
7002    writeln!(code, "}}")?;
7003
7004    Ok(code)
7005}
7006
7007// ---------------------------------------------------------------------------
7008// Tests
7009// ---------------------------------------------------------------------------
7010
7011#[cfg(test)]
7012mod tests {
7013    use super::*;
7014    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7015
7016    fn minimal_config() -> ModelConfig {
7017        ModelConfig {
7018            architecture: Architecture::Llama,
7019            hidden_size: 64,
7020            intermediate_size: 128,
7021            num_layers: 2,
7022            num_attention_heads: 4,
7023            num_kv_heads: 4,
7024            head_dim: 16,
7025            vocab_size: 256,
7026            max_seq_len: 512,
7027            rms_norm_eps: 1e-5,
7028            rope_theta: 10000.0,
7029            dtype: DType::F32,
7030            sliding_window_size: None,
7031            qkv_bias: false,
7032        }
7033    }
7034
7035    fn minimal_graph() -> Graph {
7036        Graph::new("test-metal").with_config(minimal_config())
7037    }
7038
7039    #[test]
7040    fn generate_metal_project_creates_files() {
7041        let dir = tempfile::tempdir().unwrap();
7042        let graph = minimal_graph();
7043        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7044
7045        assert!(
7046            dir.path().join("Cargo.toml").exists(),
7047            "Cargo.toml should be created"
7048        );
7049        assert!(
7050            dir.path().join("src/model.rs").exists(),
7051            "src/model.rs should be created"
7052        );
7053        assert!(
7054            dir.path().join("src/main.rs").exists(),
7055            "src/main.rs should be created"
7056        );
7057        assert!(
7058            dir.path().join("shaders/kernels.metal").exists(),
7059            "shaders/kernels.metal should be created"
7060        );
7061    }
7062
7063    #[test]
7064    fn generated_cargo_toml_has_metal_dep() {
7065        let toml = generate_cargo_toml("my-model");
7066        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7067        assert!(
7068            toml.contains("tokenizers"),
7069            "Cargo.toml should depend on tokenizers"
7070        );
7071        assert!(
7072            toml.contains("memmap2"),
7073            "Cargo.toml should depend on memmap2"
7074        );
7075        assert!(toml.contains("half"), "Cargo.toml should depend on half");
7076    }
7077
7078    #[test]
7079    fn generated_model_rs_contains_metal_code() {
7080        let config = minimal_config();
7081        let model_rs = generate_model_rs(&config).unwrap();
7082
7083        assert!(
7084            model_rs.contains("pub struct MetalModel"),
7085            "model.rs should define MetalModel struct"
7086        );
7087        assert!(
7088            model_rs.contains("matmul_pipeline: ComputePipelineState"),
7089            "MetalModel should have matmul_pipeline field"
7090        );
7091        assert!(
7092            model_rs.contains("Device::system_default()"),
7093            "model.rs should use Metal device"
7094        );
7095        assert!(
7096            model_rs.contains("new_library_with_source"),
7097            "model.rs should compile Metal shaders"
7098        );
7099        assert!(
7100            model_rs.contains("fn new(weights: &[u8])"),
7101            "MetalModel should implement new()"
7102        );
7103        assert!(
7104            model_rs.contains("fn forward(&mut self, token_id: u32)"),
7105            "MetalModel should implement forward()"
7106        );
7107    }
7108
7109    #[test]
7110    fn generated_shaders_contain_kernel_names() {
7111        let shaders = generate_metal_shaders(&minimal_config());
7112
7113        assert!(
7114            shaders.contains("kernel void matmul_vec"),
7115            "shaders should contain matmul_vec kernel"
7116        );
7117        assert!(
7118            shaders.contains("kernel void rms_norm"),
7119            "shaders should contain rms_norm kernel"
7120        );
7121        assert!(
7122            shaders.contains("kernel void rope"),
7123            "shaders should contain rope kernel"
7124        );
7125        assert!(
7126            shaders.contains("kernel void softmax"),
7127            "shaders should contain softmax kernel"
7128        );
7129        assert!(
7130            shaders.contains("kernel void silu_mul("),
7131            "shaders should contain silu_mul kernel"
7132        );
7133        assert!(
7134            shaders.contains("kernel void silu_mul_fused"),
7135            "shaders should contain silu_mul_fused kernel"
7136        );
7137        assert!(
7138            shaders.contains("kernel void elementwise_add"),
7139            "shaders should contain elementwise_add kernel"
7140        );
7141        assert!(
7142            shaders.contains("kernel void attention"),
7143            "shaders should contain attention kernel"
7144        );
7145        assert!(
7146            shaders.contains("kernel void add_inplace"),
7147            "shaders should contain add_inplace kernel"
7148        );
7149        assert!(
7150            shaders.contains("kernel void copy_buffer"),
7151            "shaders should contain copy_buffer kernel"
7152        );
7153        assert!(
7154            shaders.contains("kernel void copy_offset"),
7155            "shaders should contain copy_offset kernel"
7156        );
7157    }
7158
7159    #[test]
7160    fn generated_shaders_use_simdgroup_features() {
7161        let shaders = generate_metal_shaders(&minimal_config());
7162
7163        assert!(
7164            shaders.contains("threadgroup_barrier"),
7165            "shaders should use threadgroup barriers"
7166        );
7167        assert!(
7168            shaders.contains("threadgroup float"),
7169            "shaders should use threadgroup shared memory"
7170        );
7171        assert!(
7172            shaders.contains("thread_index_in_threadgroup"),
7173            "shaders should use threadgroup indexing"
7174        );
7175        assert!(
7176            shaders.contains("simd_sum"),
7177            "shaders should use simd_sum for warp-level reduction"
7178        );
7179        assert!(
7180            shaders.contains("simd_max"),
7181            "attention kernel should use simd_max for cooperative softmax"
7182        );
7183        assert!(
7184            shaders.contains("thread_index_in_simdgroup"),
7185            "shaders should use simdgroup lane indexing"
7186        );
7187        assert!(
7188            shaders.contains("simdgroup_index_in_threadgroup"),
7189            "shaders should use simdgroup indexing within threadgroup"
7190        );
7191        assert!(
7192            shaders.contains("float4"),
7193            "matmul_vec should use float4 vectorized loads"
7194        );
7195    }
7196
7197    #[test]
7198    fn generated_main_rs_has_tokenizer_usage() {
7199        let config = minimal_config();
7200        let main_rs = generate_main_rs("test-model", &config).unwrap();
7201
7202        assert!(
7203            main_rs.contains("tokenizers::Tokenizer"),
7204            "main.rs should use tokenizers crate"
7205        );
7206        assert!(
7207            main_rs.contains("MetalModel::new"),
7208            "main.rs should call MetalModel::new"
7209        );
7210        assert!(
7211            main_rs.contains("model.forward"),
7212            "main.rs should call model.forward"
7213        );
7214        assert!(
7215            main_rs.contains("memmap2"),
7216            "main.rs should use memmap2 for zero-copy weight loading"
7217        );
7218    }
7219
7220    #[test]
7221    fn missing_config_returns_error() {
7222        let dir = tempfile::tempdir().unwrap();
7223        let graph = Graph::new("no-config");
7224        let result = generate_metal_project(&graph, dir.path(), "fail");
7225        assert!(
7226            matches!(result, Err(MetalCodegenError::MissingConfig)),
7227            "should fail with MissingConfig when graph has no config"
7228        );
7229    }
7230
7231    #[test]
7232    fn sanitize_name_works() {
7233        assert_eq!(sanitize_name("My Model!"), "my-model");
7234        assert_eq!(sanitize_name("test_model"), "test-model");
7235        assert_eq!(sanitize_name("simple"), "simple");
7236    }
7237
7238    #[test]
7239    fn generated_forward_uses_single_command_buffer() {
7240        let config = minimal_config();
7241        let model_rs = generate_model_rs(&config).unwrap();
7242
7243        // The forward function should create exactly one command buffer.
7244        // Use the exact signature to avoid matching forward_prefill/forward_profile.
7245        let forward_start = model_rs
7246            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7247            .unwrap();
7248        let forward_body = &model_rs[forward_start..];
7249        // End at the next pub/private method
7250        let forward_end = forward_body
7251            .find("\n    pub fn forward_profile")
7252            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7253            .or_else(|| forward_body.find("\n    fn dispatch_"))
7254            .unwrap_or(forward_body.len());
7255        let forward_code = &forward_body[..forward_end];
7256
7257        // Should have exactly one new_command_buffer call
7258        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7259        assert_eq!(
7260            cmd_buf_count, 1,
7261            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7262        );
7263
7264        // Should have exactly one commit call
7265        let commit_count = forward_code.matches("cmd.commit()").count();
7266        assert_eq!(
7267            commit_count, 1,
7268            "forward() should commit exactly once, found {commit_count}"
7269        );
7270
7271        // Should wait: once for cmd + possibly once for prev_cmd drain
7272        let wait_count = forward_code.matches("wait_until_completed()").count();
7273        assert!(
7274            wait_count >= 1 && wait_count <= 2,
7275            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7276        );
7277    }
7278
7279    #[test]
7280    fn generated_model_has_preallocated_working_buffers() {
7281        let config = minimal_config();
7282        let model_rs = generate_model_rs(&config).unwrap();
7283
7284        for buf_name in &[
7285            "normed_buf",
7286            "qkv_buf",
7287            "attn_out_buf",
7288            "attn_proj_buf",
7289            "gate_up_buf",
7290            "ffn_hidden_buf",
7291            "ffn_out_buf",
7292            "add_tmp_buf",
7293        ] {
7294            assert!(
7295                model_rs.contains(&format!("{buf_name}: Buffer")),
7296                "MetalModel should have pre-allocated {buf_name} field"
7297            );
7298        }
7299    }
7300
7301    #[test]
7302    fn generated_dispatch_helpers_take_compute_encoder_ref() {
7303        let config = minimal_config();
7304        let model_rs = generate_model_rs(&config).unwrap();
7305
7306        for method in &[
7307            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7308            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7309            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7310            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7311            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7312            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7313            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7314            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7315            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7316            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7317            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7318            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7319            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7320        ] {
7321            assert!(
7322                model_rs.contains(method),
7323                "model.rs should contain dispatch helper: {method}"
7324            );
7325        }
7326    }
7327
7328    #[test]
7329    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7330        let config = minimal_config();
7331        let model_rs = generate_model_rs(&config).unwrap();
7332
7333        // Find dispatch helpers section and check none create their own encoders
7334        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7335        let helpers_code = &model_rs[helpers_start..];
7336
7337        // None of the dispatch_ helpers should call new_command_buffer
7338        assert!(
7339            !helpers_code.contains("self.queue.new_command_buffer()"),
7340            "dispatch helpers should not create their own command buffers"
7341        );
7342
7343        // None should create their own compute encoders
7344        assert!(
7345            !helpers_code.contains("new_compute_command_encoder()"),
7346            "dispatch helpers should not create their own compute encoders"
7347        );
7348
7349        // None should call end_encoding
7350        assert!(
7351            !helpers_code.contains("end_encoding()"),
7352            "dispatch helpers should not call end_encoding"
7353        );
7354
7355        // None should call commit or wait
7356        assert!(
7357            !helpers_code.contains(".commit()"),
7358            "dispatch helpers should not commit command buffers"
7359        );
7360        assert!(
7361            !helpers_code.contains("wait_until_completed"),
7362            "dispatch helpers should not wait on command buffers"
7363        );
7364    }
7365
7366    #[test]
7367    fn generated_forward_batches_compute_encoders() {
7368        let config = minimal_config();
7369        let model_rs = generate_model_rs(&config).unwrap();
7370
7371        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7372        let forward_start = model_rs
7373            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7374            .unwrap();
7375        let forward_body = &model_rs[forward_start..];
7376        let forward_end = forward_body
7377            .find("\n    pub fn forward_profile")
7378            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7379            .or_else(|| forward_body.find("\n    fn dispatch_"))
7380            .unwrap_or(forward_body.len());
7381        let forward_code = &forward_body[..forward_end];
7382
7383        // Forward should not allocate new buffers
7384        assert!(
7385            !forward_code.contains("device.new_buffer"),
7386            "forward() should not allocate new buffers per call"
7387        );
7388
7389        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7390        // Copy operations use compute copy kernels instead of blit encoders.
7391        let compute_encoder_count = forward_code
7392            .matches("new_compute_command_encoder()")
7393            .count();
7394        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7395
7396        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7397        assert_eq!(
7398            compute_encoder_count, 1,
7399            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7400        );
7401        assert_eq!(
7402            blit_encoder_count, 0,
7403            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7404        );
7405    }
7406
7407    #[test]
7408    fn generated_forward_uses_add_inplace() {
7409        let config = minimal_config();
7410        let model_rs = generate_model_rs(&config).unwrap();
7411
7412        // Should use in-place add (no blit copy-back needed)
7413        assert!(
7414            model_rs.contains("dispatch_add_inplace"),
7415            "forward() should use dispatch_add_inplace for residual connections"
7416        );
7417        assert!(
7418            model_rs.contains("add_inplace_pipeline"),
7419            "MetalModel should have add_inplace_pipeline"
7420        );
7421    }
7422
7423    fn minimal_q8_config() -> ModelConfig {
7424        ModelConfig {
7425            architecture: Architecture::Llama,
7426            hidden_size: 64,
7427            intermediate_size: 128,
7428            num_layers: 2,
7429            num_attention_heads: 4,
7430            num_kv_heads: 4,
7431            head_dim: 16,
7432            vocab_size: 256,
7433            max_seq_len: 512,
7434            rms_norm_eps: 1e-5,
7435            rope_theta: 10000.0,
7436            dtype: DType::Q8_0,
7437            sliding_window_size: None,
7438            qkv_bias: false,
7439        }
7440    }
7441
7442    #[test]
7443    fn generated_shaders_contain_q8_kernel() {
7444        let shaders = generate_metal_shaders(&minimal_config());
7445
7446        assert!(
7447            shaders.contains("kernel void matmul_vec_q8"),
7448            "shaders should contain matmul_vec_q8 kernel"
7449        );
7450        assert!(
7451            shaders.contains("device const uchar* matrix"),
7452            "matmul_vec_q8 should accept raw Q8_0 bytes"
7453        );
7454        assert!(
7455            shaders.contains("packed_short4"),
7456            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7457        );
7458        assert!(
7459            shaders.contains("as_type<char2>"),
7460            "matmul_vec_q8 should bitcast short lanes to char2"
7461        );
7462        assert!(
7463            shaders.contains("device const half*"),
7464            "matmul_vec_q8 should read f16 scale via half pointer"
7465        );
7466    }
7467
7468    #[test]
7469    fn generated_model_uses_fused_qkv_projections() {
7470        let config = minimal_config();
7471        let model_rs = generate_model_rs(&config).unwrap();
7472
7473        // Should have fused QKV weight in layer buffers
7474        assert!(
7475            model_rs.contains("qkv_weight: Buffer"),
7476            "LayerBuffers should have fused qkv_weight field"
7477        );
7478        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7479        assert!(
7480            !model_rs.contains("    q_weight: Buffer"),
7481            "LayerBuffers should not have separate q_weight field"
7482        );
7483        assert!(
7484            !model_rs.contains("    k_weight: Buffer"),
7485            "LayerBuffers should not have separate k_weight field"
7486        );
7487        assert!(
7488            !model_rs.contains("    v_weight: Buffer"),
7489            "LayerBuffers should not have separate v_weight field"
7490        );
7491
7492        // Should have fused gate_up_weight
7493        assert!(
7494            model_rs.contains("gate_up_weight: Buffer"),
7495            "LayerBuffers should have fused gate_up_weight field"
7496        );
7497        // Should NOT have separate gate/up weight fields
7498        assert!(
7499            !model_rs.contains("    gate_weight: Buffer"),
7500            "LayerBuffers should not have separate gate_weight field"
7501        );
7502        assert!(
7503            !model_rs.contains("    up_weight: Buffer"),
7504            "LayerBuffers should not have separate up_weight field"
7505        );
7506
7507        // Should have fused working buffers
7508        assert!(
7509            model_rs.contains("qkv_buf: Buffer"),
7510            "MetalModel should have fused qkv_buf"
7511        );
7512        assert!(
7513            model_rs.contains("gate_up_buf: Buffer"),
7514            "MetalModel should have fused gate_up_buf"
7515        );
7516
7517        // Forward pass should use fused dispatch
7518        assert!(
7519            model_rs.contains("dispatch_silu_mul_fused"),
7520            "forward pass should use dispatch_silu_mul_fused"
7521        );
7522        assert!(
7523            model_rs.contains("dispatch_rope_offset"),
7524            "forward pass should use dispatch_rope_offset for fused QKV"
7525        );
7526        assert!(
7527            model_rs.contains("dispatch_attention_offset"),
7528            "forward pass should use dispatch_attention_offset for fused QKV"
7529        );
7530    }
7531
7532    #[test]
7533    fn q8_model_has_matmul_q8_pipeline() {
7534        let config = minimal_q8_config();
7535        let model_rs = generate_model_rs(&config).unwrap();
7536
7537        assert!(
7538            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7539            "MetalModel should have matmul_q8_pipeline field"
7540        );
7541        assert!(
7542            model_rs.contains("matmul_q8_pipeline,"),
7543            "MetalModel Self should include matmul_q8_pipeline"
7544        );
7545    }
7546
7547    #[test]
7548    fn q8_model_uses_dispatch_matmul_q8() {
7549        let config = minimal_q8_config();
7550        let model_rs = generate_model_rs(&config).unwrap();
7551
7552        assert!(
7553            model_rs.contains("dispatch_matmul_q8"),
7554            "Q8_0 model should use dispatch_matmul_q8 for projections"
7555        );
7556        assert!(
7557            model_rs.contains("fn dispatch_matmul_q8"),
7558            "model.rs should define dispatch_matmul_q8 method"
7559        );
7560    }
7561
7562    #[test]
7563    fn q8_model_loads_raw_bytes_not_dequantized() {
7564        let config = minimal_q8_config();
7565        let model_rs = generate_model_rs(&config).unwrap();
7566
7567        // Should NOT contain dequantization code
7568        assert!(
7569            !model_rs.contains("f16_to_f32"),
7570            "Q8_0 model should not dequantize weights to f32"
7571        );
7572        assert!(
7573            !model_rs.contains("f32_data"),
7574            "Q8_0 model should not create f32 weight data"
7575        );
7576
7577        // Should load raw Q8_0 bytes directly
7578        assert!(
7579            model_rs.contains("total_raw as u64"),
7580            "Q8_0 model should load raw bytes into Metal buffer"
7581        );
7582    }
7583
7584    #[test]
7585    fn q8_model_norms_stay_f32() {
7586        let config = minimal_q8_config();
7587        let model_rs = generate_model_rs(&config).unwrap();
7588
7589        // Norm weights should still use f32 buffers
7590        assert!(
7591            model_rs.contains("let attn_norm = next_f32_buffer"),
7592            "attn_norm should use f32 buffer even for Q8_0 models"
7593        );
7594        assert!(
7595            model_rs.contains("let ffn_norm = next_f32_buffer"),
7596            "ffn_norm should use f32 buffer even for Q8_0 models"
7597        );
7598        assert!(
7599            model_rs.contains("let norm_buf = next_f32_buffer"),
7600            "final norm should use f32 buffer even for Q8_0 models"
7601        );
7602    }
7603
7604    #[test]
7605    fn q8_model_uses_fused_weight_loading() {
7606        let config = minimal_q8_config();
7607        let model_rs = generate_model_rs(&config).unwrap();
7608
7609        // Should use fused Q8 buffer loading for QKV
7610        assert!(
7611            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7612            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7613        );
7614        // Should use fused Q8 buffer loading for gate+up
7615        assert!(
7616            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7617            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7618        );
7619        // Should still use regular q8 buffer for individual weights
7620        assert!(
7621            model_rs.contains("let o_weight = next_q8_buffer"),
7622            "Q8_0 model should use next_q8_buffer for O weight"
7623        );
7624        assert!(
7625            model_rs.contains("let down_weight = next_q8_buffer"),
7626            "Q8_0 model should use next_q8_buffer for down weight"
7627        );
7628    }
7629
7630    #[test]
7631    fn f32_model_does_not_use_q8_dispatch() {
7632        let config = minimal_config();
7633        let model_rs = generate_model_rs(&config).unwrap();
7634
7635        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7636        let forward_start = model_rs
7637            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7638            .unwrap();
7639        let forward_body = &model_rs[forward_start..];
7640        let forward_end = forward_body
7641            .find("\n    fn dispatch_")
7642            .unwrap_or(forward_body.len());
7643        let forward_code = &forward_body[..forward_end];
7644
7645        assert!(
7646            !forward_code.contains("dispatch_matmul_q8"),
7647            "f32 model forward should not use dispatch_matmul_q8"
7648        );
7649    }
7650
7651    #[test]
7652    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7653        let config = minimal_q8_config();
7654        let model_rs = generate_model_rs(&config).unwrap();
7655
7656        assert!(
7657            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7658            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7659        );
7660    }
7661
7662    #[test]
7663    fn generated_model_has_double_buffered_prefill() {
7664        let config = minimal_config();
7665        let model_rs = generate_model_rs(&config).unwrap();
7666
7667        // MetalModel should have prev_cmd field for double-buffered prefill
7668        assert!(
7669            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7670            "MetalModel should have prev_cmd field for double-buffered prefill"
7671        );
7672
7673        // Should have forward_prefill method
7674        assert!(
7675            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7676            "MetalModel should have forward_prefill method"
7677        );
7678
7679        // forward() should drain prev_cmd at the start
7680        assert!(
7681            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7682            "forward() should drain prev_cmd from previous prefill"
7683        );
7684    }
7685
7686    #[test]
7687    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7688        let config = minimal_config();
7689        let main_rs = generate_main_rs("test-model", &config).unwrap();
7690
7691        assert!(
7692            main_rs.contains("forward_prefill"),
7693            "main.rs should use forward_prefill for intermediate prompt tokens"
7694        );
7695        assert!(
7696            main_rs.contains("double-buffered"),
7697            "main.rs should document double-buffered prefill"
7698        );
7699    }
7700
7701    #[test]
7702    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7703        let shaders = generate_metal_shaders(&minimal_config());
7704
7705        assert!(
7706            shaders.contains("packed_short4"),
7707            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7708        );
7709        assert!(
7710            shaders.contains("d0[0]"),
7711            "matmul_vec_q8 should index the wide pointer for row 0"
7712        );
7713        assert!(
7714            shaders.contains("as_type<char2>"),
7715            "matmul_vec_q8 should bitcast short lanes to char2"
7716        );
7717        assert!(
7718            shaders.contains("dot("),
7719            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7720        );
7721    }
7722
7723    // ── Q4_0 tests ──────────────────────────────────────────────────────
7724
7725    fn minimal_q4_config() -> ModelConfig {
7726        ModelConfig {
7727            architecture: Architecture::Llama,
7728            hidden_size: 64,
7729            intermediate_size: 128,
7730            num_layers: 2,
7731            num_attention_heads: 4,
7732            num_kv_heads: 4,
7733            head_dim: 16,
7734            vocab_size: 256,
7735            max_seq_len: 512,
7736            rms_norm_eps: 1e-5,
7737            rope_theta: 10000.0,
7738            dtype: DType::Q4_0,
7739            sliding_window_size: None,
7740            qkv_bias: false,
7741        }
7742    }
7743
7744    #[test]
7745    fn generated_shaders_contain_q4_kernel() {
7746        let shaders = generate_metal_shaders(&minimal_config());
7747
7748        assert!(
7749            shaders.contains("kernel void matmul_vec_q4"),
7750            "shaders should contain matmul_vec_q4 kernel"
7751        );
7752        assert!(
7753            shaders.contains("Q4_ROWS_PER_TG"),
7754            "shaders should define Q4_ROWS_PER_TG constant"
7755        );
7756        assert!(
7757            shaders.contains("Q4_ROWS_PER_SG"),
7758            "shaders should define Q4_ROWS_PER_SG constant"
7759        );
7760    }
7761
7762    #[test]
7763    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7764        let shaders = generate_metal_shaders(&minimal_config());
7765
7766        // Q4_0 kernel should use uchar4 for packed byte loads
7767        assert!(
7768            shaders.contains("uchar4"),
7769            "matmul_vec_q4 should use uchar4 for packed byte loads"
7770        );
7771        // Should unpack nibbles with &0xF and >>4
7772        assert!(
7773            shaders.contains("&0xF"),
7774            "matmul_vec_q4 should extract low nibble with &0xF"
7775        );
7776        assert!(
7777            shaders.contains(">>4"),
7778            "matmul_vec_q4 should extract high nibble with >>4"
7779        );
7780        // Should subtract 8 to convert unsigned to signed
7781        assert!(
7782            shaders.contains("-8)"),
7783            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7784        );
7785        // Should use 18-byte block size
7786        assert!(
7787            shaders.contains("blk * 18"),
7788            "matmul_vec_q4 should use 18-byte block stride"
7789        );
7790    }
7791
7792    #[test]
7793    fn q4_model_has_matmul_q4_pipeline() {
7794        let config = minimal_q4_config();
7795        let model_rs = generate_model_rs(&config).unwrap();
7796
7797        assert!(
7798            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7799            "MetalModel should have matmul_q4_pipeline field"
7800        );
7801        assert!(
7802            model_rs.contains("matmul_q4_pipeline,"),
7803            "MetalModel Self should include matmul_q4_pipeline"
7804        );
7805    }
7806
7807    #[test]
7808    fn q4_model_uses_dispatch_matmul_q4() {
7809        let config = minimal_q4_config();
7810        let model_rs = generate_model_rs(&config).unwrap();
7811
7812        assert!(
7813            model_rs.contains("dispatch_matmul_q4"),
7814            "Q4_0 model should use dispatch_matmul_q4 for projections"
7815        );
7816        assert!(
7817            model_rs.contains("fn dispatch_matmul_q4"),
7818            "model.rs should define dispatch_matmul_q4 method"
7819        );
7820    }
7821
7822    #[test]
7823    fn q4_model_loads_raw_bytes_not_dequantized() {
7824        let config = minimal_q4_config();
7825        let model_rs = generate_model_rs(&config).unwrap();
7826
7827        // Should NOT contain dequantization code
7828        assert!(
7829            !model_rs.contains("f16_to_f32"),
7830            "Q4_0 model should not dequantize weights to f32"
7831        );
7832
7833        // Should load raw Q4_0 bytes directly
7834        assert!(
7835            model_rs.contains("total_raw as u64"),
7836            "Q4_0 model should load raw bytes into Metal buffer"
7837        );
7838    }
7839
7840    #[test]
7841    fn q4_model_norms_stay_f32() {
7842        let config = minimal_q4_config();
7843        let model_rs = generate_model_rs(&config).unwrap();
7844
7845        assert!(
7846            model_rs.contains("let attn_norm = next_f32_buffer"),
7847            "attn_norm should use f32 buffer even for Q4_0 models"
7848        );
7849        assert!(
7850            model_rs.contains("let ffn_norm = next_f32_buffer"),
7851            "ffn_norm should use f32 buffer even for Q4_0 models"
7852        );
7853        assert!(
7854            model_rs.contains("let norm_buf = next_f32_buffer"),
7855            "final norm should use f32 buffer even for Q4_0 models"
7856        );
7857    }
7858
7859    #[test]
7860    fn q4_model_uses_fused_weight_loading() {
7861        let config = minimal_q4_config();
7862        let model_rs = generate_model_rs(&config).unwrap();
7863
7864        assert!(
7865            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7866            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7867        );
7868        assert!(
7869            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7870            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7871        );
7872        assert!(
7873            model_rs.contains("let o_weight = next_q4_buffer"),
7874            "Q4_0 model should use next_q4_buffer for O weight"
7875        );
7876        assert!(
7877            model_rs.contains("let down_weight = next_q4_buffer"),
7878            "Q4_0 model should use next_q4_buffer for down weight"
7879        );
7880    }
7881
7882    #[test]
7883    fn attention_flash_batch_kernel_exists() {
7884        // The flash kernel is still wired into the library (pipeline, kernel
7885        // source).  Dispatch is currently routed to the legacy path pending a
7886        // fix for a numerical issue discovered after the prompt-chunking fix.
7887        let config = minimal_config();
7888        let model_rs = generate_model_rs(&config).unwrap();
7889        let shaders = generate_metal_shaders(&config);
7890
7891        assert!(
7892            shaders.contains("kernel void attention_flash_batch"),
7893            "shaders.metal must still contain the attention_flash_batch kernel"
7894        );
7895        assert!(
7896            shaders.contains("FLASH_K_TILE"),
7897            "flash kernel must tile K/V with a FLASH_K_TILE constant"
7898        );
7899        assert!(
7900            model_rs.contains("attention_flash_batch_pipeline"),
7901            "MetalModel must register the flash attention pipeline"
7902        );
7903    }
7904
7905    #[test]
7906    fn attention_mma_flash_batch_kernel_wired() {
7907        // MMA-accelerated flash attention (issue #212).  Default-on in v0.7.0
7908        // when HEAD_DIM ≤ 128 and num_tokens ≥ 8.  FORGE_MMA_ATTN=0 opts out.
7909        let config = minimal_config();
7910        let model_rs = generate_model_rs(&config).unwrap();
7911        let shaders = generate_metal_shaders(&config);
7912
7913        assert!(
7914            shaders.contains("kernel void attention_mma_flash_batch"),
7915            "shaders.metal must contain the MMA flash kernel"
7916        );
7917        assert!(
7918            shaders.contains("FLASH_MMA_Q_BLOCK"),
7919            "MMA flash kernel must define Q_BLOCK tiling constant"
7920        );
7921        assert!(
7922            shaders.contains("simdgroup_multiply_accumulate"),
7923            "MMA flash kernel must use hardware MMA"
7924        );
7925        assert!(
7926            model_rs.contains("attention_mma_flash_batch_pipeline"),
7927            "MetalModel must register the MMA flash pipeline"
7928        );
7929        assert!(
7930            model_rs.contains("mma_opt_out"),
7931            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out"
7932        );
7933        assert!(
7934            model_rs.contains("!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8"),
7935            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8"
7936        );
7937    }
7938
7939    #[test]
7940    fn forward_prefill_batch_chunks_by_max_batch_size() {
7941        // Regression: prior to v0.6.4 forward_prefill_batch truncated prompts
7942        // longer than MAX_BATCH_SIZE (512) tokens via `.min(MAX_BATCH_SIZE)`,
7943        // silently dropping the middle of long prompts.  Must now loop over
7944        // MAX_BATCH_SIZE-sized chunks and carry KV-cache state across them.
7945        let config = minimal_config();
7946        let model_rs = generate_model_rs(&config).unwrap();
7947        assert!(
7948            model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
7949            "forward_prefill_batch must chunk long prompts"
7950        );
7951        assert!(
7952            !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
7953            "the old truncation path must be gone"
7954        );
7955    }
7956
7957    #[test]
7958    fn qwen2_qkv_bias_wired_through_metal_codegen() {
7959        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
7960        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
7961        // loader, pipeline, and dispatch call in the expected places.
7962        let config = ModelConfig {
7963            architecture: Architecture::Qwen2,
7964            qkv_bias: true,
7965            ..minimal_config()
7966        };
7967        let model_rs = generate_model_rs(&config).unwrap();
7968
7969        assert!(
7970            model_rs.contains("qkv_bias: Buffer"),
7971            "Qwen2 LayerBuffers must declare qkv_bias field"
7972        );
7973        assert!(
7974            model_rs.contains("let qkv_bias = next_f32_buffer"),
7975            "Qwen2 layer init must load the bias from the weight blob"
7976        );
7977        assert!(
7978            model_rs.contains("add_bias_batch_pipeline"),
7979            "Qwen2 model struct must include the add_bias_batch_pipeline"
7980        );
7981        assert!(
7982            model_rs.contains("fn dispatch_add_bias_batch"),
7983            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
7984        );
7985        assert!(
7986            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
7987            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
7988        );
7989        assert!(
7990            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
7991            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
7992        );
7993
7994        // The add_bias_batch MSL kernel must be in the shader source.
7995        let shaders = generate_metal_shaders(&config);
7996        assert!(
7997            shaders.contains("kernel void add_bias_batch"),
7998            "shaders.metal must contain the add_bias_batch kernel"
7999        );
8000    }
8001
8002    #[test]
8003    fn llama_does_not_emit_qkv_bias_machinery() {
8004        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
8005        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
8006        let config = minimal_config();
8007        assert!(!config.qkv_bias);
8008        let model_rs = generate_model_rs(&config).unwrap();
8009        assert!(
8010            !model_rs.contains("qkv_bias: Buffer"),
8011            "Llama must not have qkv_bias field"
8012        );
8013        assert!(
8014            !model_rs.contains("add_bias_batch_pipeline"),
8015            "Llama must not pull in add_bias_batch_pipeline"
8016        );
8017        assert!(
8018            !model_rs.contains("dispatch_add_bias_batch"),
8019            "Llama must not call dispatch_add_bias_batch"
8020        );
8021    }
8022
8023    #[test]
8024    fn q4_dispatch_helper_takes_compute_encoder_ref() {
8025        let config = minimal_q4_config();
8026        let model_rs = generate_model_rs(&config).unwrap();
8027
8028        assert!(
8029            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
8030            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
8031        );
8032    }
8033
8034    #[test]
8035    fn f32_model_does_not_use_q4_dispatch() {
8036        let config = minimal_config();
8037        let model_rs = generate_model_rs(&config).unwrap();
8038
8039        let forward_start = model_rs
8040            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8041            .unwrap();
8042        let forward_body = &model_rs[forward_start..];
8043        let forward_end = forward_body
8044            .find("\n    fn dispatch_")
8045            .unwrap_or(forward_body.len());
8046        let forward_code = &forward_body[..forward_end];
8047
8048        assert!(
8049            !forward_code.contains("dispatch_matmul_q4"),
8050            "f32 model forward should not use dispatch_matmul_q4"
8051        );
8052    }
8053
8054    #[test]
8055    fn q4_model_lm_head_uses_q4_buffer() {
8056        let config = minimal_q4_config();
8057        let model_rs = generate_model_rs(&config).unwrap();
8058
8059        assert!(
8060            model_rs.contains("let lm_head_buf = next_q4_buffer"),
8061            "Q4_0 model should use next_q4_buffer for lm_head"
8062        );
8063    }
8064
8065    #[test]
8066    fn vec_tile_size_matches_model_dimensions() {
8067        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
8068        let small = minimal_config();
8069        let shaders_small = generate_metal_shaders(&small);
8070        assert!(
8071            shaders_small.contains("vec_tile[128]"),
8072            "vec_tile should be sized to max(hidden, intermediate) = 128"
8073        );
8074
8075        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
8076        let mut large = minimal_config();
8077        large.hidden_size = 2048;
8078        large.intermediate_size = 8192;
8079        let shaders_large = generate_metal_shaders(&large);
8080        assert!(
8081            shaders_large.contains("vec_tile[8192]"),
8082            "vec_tile should be 8192 for models with intermediate=8192"
8083        );
8084        assert!(
8085            !shaders_large.contains("vec_tile[4096]"),
8086            "vec_tile should NOT be hardcoded to 4096"
8087        );
8088    }
8089
8090    #[test]
8091    fn generated_cargo_toml_has_server_deps() {
8092        let toml = generate_cargo_toml("my-model");
8093        assert!(
8094            toml.contains("tiny_http"),
8095            "Cargo.toml should depend on tiny_http"
8096        );
8097        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
8098        assert!(
8099            toml.contains("serde_json"),
8100            "Cargo.toml should depend on serde_json"
8101        );
8102    }
8103
8104    #[test]
8105    fn generated_main_rs_has_serve_mode() {
8106        let config = minimal_config();
8107        let main_rs = generate_main_rs("test-model", &config).unwrap();
8108
8109        assert!(
8110            main_rs.contains("--serve"),
8111            "main.rs should parse --serve flag"
8112        );
8113        assert!(
8114            main_rs.contains("--port"),
8115            "main.rs should parse --port flag"
8116        );
8117        assert!(
8118            main_rs.contains("fn serve("),
8119            "main.rs should define serve function"
8120        );
8121        assert!(
8122            main_rs.contains("tiny_http::Server::http"),
8123            "main.rs should create tiny_http server"
8124        );
8125    }
8126
8127    #[test]
8128    fn generated_main_rs_has_chat_completions_endpoint() {
8129        let config = minimal_config();
8130        let main_rs = generate_main_rs("test-model", &config).unwrap();
8131
8132        assert!(
8133            main_rs.contains("/v1/chat/completions"),
8134            "main.rs should handle /v1/chat/completions endpoint"
8135        );
8136        assert!(
8137            main_rs.contains("/v1/models"),
8138            "main.rs should handle /v1/models endpoint"
8139        );
8140        assert!(
8141            main_rs.contains("/health"),
8142            "main.rs should handle /health endpoint"
8143        );
8144    }
8145
8146    #[test]
8147    fn generated_main_rs_has_sse_streaming() {
8148        let config = minimal_config();
8149        let main_rs = generate_main_rs("test-model", &config).unwrap();
8150
8151        assert!(
8152            main_rs.contains("text/event-stream"),
8153            "main.rs should set SSE content type for streaming"
8154        );
8155        assert!(
8156            main_rs.contains("chat.completion.chunk"),
8157            "main.rs should emit SSE chunks"
8158        );
8159        assert!(
8160            main_rs.contains("[DONE]"),
8161            "main.rs should emit [DONE] sentinel"
8162        );
8163    }
8164
8165    #[test]
8166    fn generated_main_rs_has_chat_message_formatting() {
8167        let config = minimal_config();
8168        let main_rs = generate_main_rs("test-model", &config).unwrap();
8169
8170        assert!(
8171            main_rs.contains("fn format_chat_messages"),
8172            "main.rs should define format_chat_messages function"
8173        );
8174        assert!(
8175            main_rs.contains("<|im_start|>"),
8176            "main.rs should use ChatML format"
8177        );
8178        assert!(
8179            main_rs.contains("<|im_end|>"),
8180            "main.rs should use ChatML format"
8181        );
8182    }
8183
8184    #[test]
8185    fn generated_main_rs_has_request_types() {
8186        let config = minimal_config();
8187        let main_rs = generate_main_rs("test-model", &config).unwrap();
8188
8189        assert!(
8190            main_rs.contains("struct ChatRequest"),
8191            "main.rs should define ChatRequest struct"
8192        );
8193        assert!(
8194            main_rs.contains("struct ChatMessage"),
8195            "main.rs should define ChatMessage struct"
8196        );
8197        assert!(
8198            main_rs.contains("Deserialize"),
8199            "main.rs should derive Deserialize for request types"
8200        );
8201    }
8202
8203    #[test]
8204    fn generated_model_has_reset_method() {
8205        let config = minimal_config();
8206        let model_rs = generate_model_rs(&config).unwrap();
8207
8208        assert!(
8209            model_rs.contains("pub fn reset(&mut self)"),
8210            "model.rs should have a reset() method for multi-request serving"
8211        );
8212        assert!(
8213            model_rs.contains("self.pos = 0"),
8214            "reset() should reset position to 0"
8215        );
8216    }
8217
8218    #[test]
8219    fn generated_main_rs_cli_mode_still_works() {
8220        let config = minimal_config();
8221        let main_rs = generate_main_rs("test-model", &config).unwrap();
8222
8223        // CLI mode should still be functional
8224        assert!(
8225            main_rs.contains("fn cli_mode("),
8226            "main.rs should define cli_mode function"
8227        );
8228        assert!(
8229            main_rs.contains("model.forward"),
8230            "main.rs should call model.forward"
8231        );
8232        assert!(
8233            main_rs.contains("model.forward_prefill"),
8234            "main.rs should call model.forward_prefill"
8235        );
8236    }
8237
8238    // ── Batched prefill tests ──────────────────────────────────────────
8239
8240    #[test]
8241    fn generated_shaders_contain_batch_kernels() {
8242        let shaders = generate_metal_shaders(&minimal_config());
8243
8244        assert!(
8245            shaders.contains("kernel void matmul_vec_batch"),
8246            "shaders should contain matmul_vec_batch kernel"
8247        );
8248        assert!(
8249            shaders.contains("kernel void matmul_vec_q8_batch"),
8250            "shaders should contain matmul_vec_q8_batch kernel"
8251        );
8252        assert!(
8253            shaders.contains("kernel void matmul_q8_gemm_batch"),
8254            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
8255        );
8256        assert!(
8257            shaders.contains("kernel void matmul_vec_q4_batch"),
8258            "shaders should contain matmul_vec_q4_batch kernel"
8259        );
8260        assert!(
8261            shaders.contains("kernel void rms_norm_batch"),
8262            "shaders should contain rms_norm_batch kernel"
8263        );
8264        assert!(
8265            shaders.contains("kernel void silu_mul_fused_batch"),
8266            "shaders should contain silu_mul_fused_batch kernel"
8267        );
8268        assert!(
8269            shaders.contains("kernel void add_inplace_batch"),
8270            "shaders should contain add_inplace_batch kernel"
8271        );
8272        assert!(
8273            shaders.contains("kernel void copy_embedding_batch"),
8274            "shaders should contain copy_embedding_batch kernel"
8275        );
8276    }
8277
8278    #[test]
8279    fn generated_model_has_batch_pipelines() {
8280        let config = minimal_config();
8281        let model_rs = generate_model_rs(&config).unwrap();
8282
8283        for pipeline in &[
8284            "matmul_batch_pipeline",
8285            "matmul_q8_batch_pipeline",
8286            "matmul_q8_gemm_batch_pipeline",
8287            "matmul_q4_batch_pipeline",
8288            "rms_norm_batch_pipeline",
8289            "rope_batch_pipeline",
8290            "silu_mul_fused_batch_pipeline",
8291            "add_inplace_batch_pipeline",
8292            "copy_embedding_batch_pipeline",
8293            "attention_batch_pipeline",
8294            "copy_kv_batch_pipeline",
8295        ] {
8296            assert!(
8297                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
8298                "MetalModel should have {pipeline} field"
8299            );
8300        }
8301    }
8302
8303    #[test]
8304    fn generated_model_has_batch_buffers() {
8305        let config = minimal_config();
8306        let model_rs = generate_model_rs(&config).unwrap();
8307
8308        for buf in &[
8309            "batch_hidden_buf",
8310            "batch_residual_buf",
8311            "batch_qkv_buf",
8312            "batch_attn_out_buf",
8313            "batch_attn_proj_buf",
8314            "batch_gate_up_buf",
8315            "batch_ffn_hidden_buf",
8316            "batch_ffn_out_buf",
8317            "batch_tokens_buf",
8318            "batch_positions_buf",
8319        ] {
8320            assert!(
8321                model_rs.contains(&format!("{buf}: Buffer")),
8322                "MetalModel should have {buf} field"
8323            );
8324        }
8325    }
8326
8327    #[test]
8328    fn generated_model_has_forward_prefill_batch() {
8329        let config = minimal_config();
8330        let model_rs = generate_model_rs(&config).unwrap();
8331
8332        assert!(
8333            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8334            "MetalModel should have forward_prefill_batch method"
8335        );
8336
8337        // forward_prefill should delegate to forward_prefill_batch
8338        assert!(
8339            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8340            "forward_prefill should delegate to forward_prefill_batch"
8341        );
8342    }
8343
8344    #[test]
8345    fn generated_model_has_max_batch_size_constant() {
8346        let config = minimal_config();
8347        let model_rs = generate_model_rs(&config).unwrap();
8348
8349        assert!(
8350            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8351            "model.rs should define MAX_BATCH_SIZE constant"
8352        );
8353    }
8354
8355    #[test]
8356    fn forward_prefill_batch_uses_batch_dispatch() {
8357        let config = minimal_config();
8358        let model_rs = generate_model_rs(&config).unwrap();
8359
8360        let batch_start = model_rs
8361            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8362            .unwrap();
8363        let batch_body = &model_rs[batch_start..];
8364        let batch_end = batch_body
8365            .find("\n    pub fn reset")
8366            .unwrap_or(batch_body.len());
8367        let batch_code = &batch_body[..batch_end];
8368
8369        // Should use batched dispatch methods
8370        assert!(
8371            batch_code.contains("dispatch_rms_norm_batch"),
8372            "forward_prefill_batch should use dispatch_rms_norm_batch"
8373        );
8374        assert!(
8375            batch_code.contains("dispatch_copy_embedding_batch"),
8376            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8377        );
8378        assert!(
8379            batch_code.contains("dispatch_silu_mul_fused_batch"),
8380            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8381        );
8382        // Should use batched causal attention dispatch
8383        assert!(
8384            batch_code.contains("dispatch_attention_batch"),
8385            "forward_prefill_batch should use dispatch_attention_batch"
8386        );
8387        // Should use fused KV cache copy (both K and V in one dispatch)
8388        assert!(
8389            batch_code.contains("dispatch_copy_kv_both_batch"),
8390            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8391        );
8392        // Should use fused RoPE Q+K dispatch
8393        assert!(
8394            batch_code.contains("dispatch_rope_qk_batch"),
8395            "forward_prefill_batch should use dispatch_rope_qk_batch"
8396        );
8397    }
8398
8399    #[test]
8400    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8401        let config = minimal_q8_config();
8402        let model_rs = generate_model_rs(&config).unwrap();
8403
8404        let batch_start = model_rs
8405            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8406            .unwrap();
8407        let batch_body = &model_rs[batch_start..];
8408        let batch_end = batch_body
8409            .find("\n    pub fn reset")
8410            .unwrap_or(batch_body.len());
8411        let batch_code = &batch_body[..batch_end];
8412
8413        assert!(
8414            batch_code.contains("dispatch_matmul_q8_batch"),
8415            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8416        );
8417    }
8418
8419    #[test]
8420    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8421        let config = minimal_q4_config();
8422        let model_rs = generate_model_rs(&config).unwrap();
8423
8424        let batch_start = model_rs
8425            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8426            .unwrap();
8427        let batch_body = &model_rs[batch_start..];
8428        let batch_end = batch_body
8429            .find("\n    pub fn reset")
8430            .unwrap_or(batch_body.len());
8431        let batch_code = &batch_body[..batch_end];
8432
8433        assert!(
8434            batch_code.contains("dispatch_matmul_q4_batch"),
8435            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8436        );
8437    }
8438
8439    #[test]
8440    fn generated_main_rs_uses_batched_prefill() {
8441        let config = minimal_config();
8442        let main_rs = generate_main_rs("test-model", &config).unwrap();
8443
8444        assert!(
8445            main_rs.contains("forward_prefill_batch"),
8446            "main.rs should use forward_prefill_batch for prompt tokens"
8447        );
8448    }
8449
8450    #[test]
8451    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8452        let config = minimal_config();
8453        let model_rs = generate_model_rs(&config).unwrap();
8454
8455        let batch_start = model_rs
8456            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8457            .unwrap();
8458        let batch_body = &model_rs[batch_start..];
8459        let batch_end = batch_body
8460            .find("\n    pub fn reset")
8461            .unwrap_or(batch_body.len());
8462        let batch_code = &batch_body[..batch_end];
8463
8464        assert!(
8465            batch_code.contains("dispatch_matmul_batch"),
8466            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8467        );
8468        // Should NOT use Q8 or Q4 batch dispatch
8469        assert!(
8470            !batch_code.contains("dispatch_matmul_q8_batch"),
8471            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8472        );
8473        assert!(
8474            !batch_code.contains("dispatch_matmul_q4_batch"),
8475            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8476        );
8477    }
8478
8479    #[test]
8480    fn forward_uses_cpu_embedding_lookup() {
8481        let config = minimal_config();
8482        let model_rs = generate_model_rs(&config).unwrap();
8483
8484        // Find just the forward() body (not forward_profile)
8485        let forward_start = model_rs
8486            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8487            .unwrap();
8488        let forward_body = &model_rs[forward_start..];
8489        let forward_end = forward_body
8490            .find("\n    pub fn forward_profile")
8491            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8492            .unwrap_or(forward_body.len());
8493        let forward_code = &forward_body[..forward_end];
8494
8495        // forward() should use CPU memcpy for embedding lookup (unified memory)
8496        assert!(
8497            forward_code.contains("embed_buf.contents()"),
8498            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8499        );
8500        assert!(
8501            forward_code.contains("copy_nonoverlapping"),
8502            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8503        );
8504        // forward() should NOT use GPU dispatch for embedding
8505        assert!(
8506            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8507            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8508        );
8509    }
8510
8511    #[test]
8512    fn forward_profile_method_exists() {
8513        let config = minimal_config();
8514        let model_rs = generate_model_rs(&config).unwrap();
8515
8516        assert!(
8517            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8518            "MetalModel should have forward_profile() method"
8519        );
8520        // Profile method should print timing information
8521        assert!(
8522            model_rs.contains("[profile]"),
8523            "forward_profile() should print timing with [profile] prefix"
8524        );
8525        assert!(
8526            model_rs.contains("d_embed"),
8527            "forward_profile() should measure embedding time"
8528        );
8529        assert!(
8530            model_rs.contains("d_layers"),
8531            "forward_profile() should measure layer time"
8532        );
8533        assert!(
8534            model_rs.contains("d_logits"),
8535            "forward_profile() should measure logits time"
8536        );
8537    }
8538
8539    #[test]
8540    fn generated_cli_has_profile_flag() {
8541        let config = minimal_config();
8542        let main_rs = generate_main_rs("test-model", &config).unwrap();
8543
8544        assert!(
8545            main_rs.contains("--profile"),
8546            "CLI should support --profile flag"
8547        );
8548        assert!(
8549            main_rs.contains("forward_profile"),
8550            "CLI should call forward_profile when --profile is set"
8551        );
8552    }
8553
8554    #[test]
8555    fn generated_cli_has_thermal_yield() {
8556        let config = minimal_config();
8557        let main_rs = generate_main_rs("test-model", &config).unwrap();
8558
8559        assert!(
8560            main_rs.contains("yield_now()"),
8561            "CLI generation loop should include thread::yield_now() for thermal management"
8562        );
8563    }
8564
8565    // ── Real-world validation tests ──────────────────────────────────────
8566
8567    #[test]
8568    fn generated_forward_handles_single_token_prompt() {
8569        // With a single token (the first prompt token), forward() should work
8570        // at pos=0 where seq_len=1. The attention kernel must handle the case
8571        // where there is only one KV entry (no prefill context).
8572        let config = minimal_config();
8573        let model_rs = generate_model_rs(&config).unwrap();
8574
8575        // The forward function should accept any u32 token_id (no minimum pos guard)
8576        let forward_start = model_rs
8577            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8578            .expect("forward() must exist");
8579        let forward_body = &model_rs[forward_start..forward_start + 400];
8580
8581        // Should NOT require pos > 0 or seq_len > 1
8582        assert!(
8583            !forward_body.contains("assert!(self.pos > 0"),
8584            "forward() must accept pos=0 (first token with no prefill)"
8585        );
8586
8587        // The attention kernel should handle seq_len=1 via the pos field
8588        assert!(
8589            model_rs.contains("self.pos"),
8590            "forward() should use self.pos to track sequence position"
8591        );
8592    }
8593
8594    #[test]
8595    fn generated_reset_clears_kv_cache_position() {
8596        // After reset(), the model should be in a clean state. The pos field
8597        // must be 0 so new generation starts from scratch.
8598        let config = minimal_config();
8599        let model_rs = generate_model_rs(&config).unwrap();
8600
8601        let reset_start = model_rs
8602            .find("pub fn reset(&mut self)")
8603            .expect("reset() must exist");
8604        let reset_body = &model_rs[reset_start..reset_start + 200];
8605
8606        // Reset must zero the position counter
8607        assert!(
8608            reset_body.contains("self.pos = 0"),
8609            "reset() must set self.pos = 0"
8610        );
8611
8612        // Verify reset clears prev_cmd (double-buffering state)
8613        assert!(
8614            reset_body.contains("self.prev_cmd = None"),
8615            "reset() should clear prev_cmd for clean command buffer state"
8616        );
8617    }
8618
8619    #[test]
8620    fn generated_serve_handles_empty_messages_gracefully() {
8621        // The serve endpoint should not crash when receiving an empty messages array.
8622        // The format_chat_messages function should handle this gracefully.
8623        let config = minimal_config();
8624        let main_rs = generate_main_rs("test-model", &config).unwrap();
8625
8626        // The format_chat_messages function should exist and handle empty input
8627        let format_fn_start = main_rs
8628            .find("fn format_chat_messages")
8629            .expect("format_chat_messages must exist");
8630        let format_fn_body =
8631            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8632
8633        // It should iterate over messages (an empty slice produces an empty loop)
8634        assert!(
8635            format_fn_body.contains("for msg in messages"),
8636            "format_chat_messages should iterate over the messages slice"
8637        );
8638        // It should always append the assistant prompt suffix
8639        assert!(
8640            format_fn_body.contains("<|im_start|>assistant"),
8641            "format_chat_messages should always append assistant prompt header"
8642        );
8643
8644        // The serve function should call model.reset() before each request
8645        let serve_fn_start = main_rs
8646            .find("fn serve(")
8647            .expect("serve function must exist");
8648        let serve_fn_body = &main_rs[serve_fn_start..];
8649        assert!(
8650            serve_fn_body.contains("model.reset()"),
8651            "serve function should reset model between requests"
8652        );
8653    }
8654
8655    #[test]
8656    fn generated_model_forward_increments_pos() {
8657        // Each forward() call must increment self.pos so the next token
8658        // uses the correct RoPE position and KV cache offset.
8659        let config = minimal_config();
8660        let model_rs = generate_model_rs(&config).unwrap();
8661
8662        let forward_start = model_rs
8663            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8664            .unwrap();
8665        let forward_body = &model_rs[forward_start..];
8666        let forward_end = forward_body
8667            .find("\n    pub fn forward_profile")
8668            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8669            .or_else(|| forward_body.find("\n    fn dispatch_"))
8670            .unwrap_or(forward_body.len());
8671        let forward_code = &forward_body[..forward_end];
8672
8673        assert!(
8674            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8675            "forward() must increment self.pos after processing a token"
8676        );
8677    }
8678}