Skip to main content

forgellm_codegen_metal/
lib.rs

//! Forge Metal Code Generation — native Apple Silicon GPU inference.
//!
//! Generates a complete Cargo project that runs GPU inference via the `metal`
//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a valid Cargo package / binary name.
///
/// The steps are:
/// 1. Lowercase the name.
/// 2. Replace every character outside `[a-z0-9-]` with `-`.
/// 3. Trim leading and trailing dashes.
///
/// Two results that Cargo would reject are then repaired:
/// - an empty result (e.g. from `""` or `"!!!"`) becomes `"model"`;
/// - a result starting with a digit (e.g. from `"7B"`) gets a `model-`
///   prefix, because Cargo package names may not begin with a digit.
///
/// Only ASCII alphanumerics are kept. Non-ASCII crate names are not
/// supported by Rust tooling, so characters such as `é` are treated like
/// punctuation.
fn sanitize_name(name: &str) -> String {
    let dashed: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = dashed.trim_matches('-');

    match trimmed.chars().next() {
        None => String::from("model"),
        Some(first) if first.is_ascii_digit() => format!("model-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    // The attention-scores shared array must be at least effective_seq_len
122    // elements; anything smaller silently overflows for long prompts.  For
123    // small-context models (135M at 2K), a smaller array saves threadgroup
124    // memory and improves occupancy — so we size it precisely.
125    let attn_scores_size = config.max_seq_len.min(4096);
126    r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154    device const float* matrix [[buffer(0)]],
155    device const float* vector [[buffer(1)]],
156    device float* output       [[buffer(2)]],
157    constant uint& rows        [[buffer(3)]],
158    constant uint& cols        [[buffer(4)]],
159    uint tgid [[threadgroup_position_in_grid]],
160    uint tid [[thread_index_in_threadgroup]],
161    uint simd_lane [[thread_index_in_simdgroup]],
162    uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164    // Cooperatively load vector into threadgroup shared memory
165    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166    for (uint i = tid; i < cols; i += 256) {
167        vec_tile[i] = vector[i];
168    }
169    threadgroup_barrier(mem_flags::mem_threadgroup);
170
171    // Each simdgroup handles 8 consecutive rows
172    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173    if (row_base >= rows) return;
174
175    uint base0 = row_base * cols;
176    uint base1 = (row_base + 1) * cols;
177    uint base2 = (row_base + 2) * cols;
178    uint base3 = (row_base + 3) * cols;
179    uint base4 = (row_base + 4) * cols;
180    uint base5 = (row_base + 5) * cols;
181    uint base6 = (row_base + 6) * cols;
182    uint base7 = (row_base + 7) * cols;
183
184    // float4 vectorized accumulation across 8 rows
185    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
186    float4 sum4_0 = float4(0.0f);
187    float4 sum4_1 = float4(0.0f);
188    float4 sum4_2 = float4(0.0f);
189    float4 sum4_3 = float4(0.0f);
190    float4 sum4_4 = float4(0.0f);
191    float4 sum4_5 = float4(0.0f);
192    float4 sum4_6 = float4(0.0f);
193    float4 sum4_7 = float4(0.0f);
194
195    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205    }
206
207    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216    // Handle remaining elements (cols not divisible by 128)
217    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218        float vv = vec_tile[j];
219        sum0 += matrix[base0 + j] * vv;
220        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227    }
228
229    // Simdgroup hardware warp-level reduction
230    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235    // Only first lane writes the results
236    if (simd_lane == 0) {
237        if (row_base     < rows) output[row_base]     = sum0;
238        if (row_base + 1 < rows) output[row_base + 1] = sum1;
239        if (row_base + 2 < rows) output[row_base + 2] = sum2;
240        if (row_base + 3 < rows) output[row_base + 3] = sum3;
241        if (row_base + 4 < rows) output[row_base + 4] = sum4;
242        if (row_base + 5 < rows) output[row_base + 5] = sum5;
243        if (row_base + 6 < rows) output[row_base + 6] = sum6;
244        if (row_base + 7 < rows) output[row_base + 7] = sum7;
245    }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253    device const float* input   [[buffer(0)]],
254    device const float* weight  [[buffer(1)]],
255    device float* output        [[buffer(2)]],
256    constant uint& n            [[buffer(3)]],
257    constant float& eps         [[buffer(4)]],
258    uint tid [[thread_index_in_threadgroup]])
259{
260    // Each thread accumulates partial sum-of-squares
261    float sum_sq = 0.0f;
262    for (uint i = tid; i < n; i += 256) {
263        float v = input[i];
264        sum_sq += v * v;
265    }
266
267    // Simdgroup-level reduction (hardware warp sum)
268    sum_sq = simd_sum(sum_sq);
269
270    // Cross-simdgroup reduction via shared memory
271    threadgroup float shared[8];
272    uint simd_id = tid / 32;
273    uint simd_lane = tid % 32;
274    if (simd_lane == 0) {
275        shared[simd_id] = sum_sq;
276    }
277    threadgroup_barrier(mem_flags::mem_threadgroup);
278
279    // First thread computes final inverse RMS
280    if (tid == 0) {
281        float total = 0.0f;
282        for (uint i = 0; i < 8; i++) {
283            total += shared[i];
284        }
285        shared[0] = fast::rsqrt(total / float(n) + eps);
286    }
287    threadgroup_barrier(mem_flags::mem_threadgroup);
288
289    float inv_rms = shared[0];
290
291    // Normalize
292    for (uint i = tid; i < n; i += 256) {
293        output[i] = input[i] * inv_rms * weight[i];
294    }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301    device float* data        [[buffer(0)]],
302    constant uint& num_heads  [[buffer(1)]],
303    constant uint& head_dim   [[buffer(2)]],
304    constant uint& pos        [[buffer(3)]],
305    constant float& theta     [[buffer(4)]],
306    uint id [[thread_position_in_grid]])
307{
308    uint half_dim = head_dim / 2;
309    uint total_pairs = num_heads * half_dim;
310    if (id >= total_pairs) return;
311
312    uint h = id / half_dim;
313    uint i = id % half_dim;
314    uint off = h * head_dim;
315
316    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317    float angle = float(pos) * freq;
318    float c = cos(angle);
319    float s = sin(angle);
320
321    float x0 = data[off + 2 * i];
322    float x1 = data[off + 2 * i + 1];
323    data[off + 2 * i]     = x0 * c - x1 * s;
324    data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331    device float* data       [[buffer(0)]],
332    constant uint& n         [[buffer(1)]],
333    uint tid [[thread_index_in_threadgroup]],
334    uint tg_size [[threads_per_threadgroup]])
335{
336    threadgroup float shared_val[256];
337
338    // Pass 1: find max
339    float local_max = -INFINITY;
340    for (uint i = tid; i < n; i += tg_size) {
341        local_max = max(local_max, data[i]);
342    }
343    shared_val[tid] = local_max;
344    threadgroup_barrier(mem_flags::mem_threadgroup);
345
346    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347        if (tid < stride) {
348            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349        }
350        threadgroup_barrier(mem_flags::mem_threadgroup);
351    }
352    float max_val = shared_val[0];
353    threadgroup_barrier(mem_flags::mem_threadgroup);
354
355    // Pass 2: exp and sum
356    float local_sum = 0.0f;
357    for (uint i = tid; i < n; i += tg_size) {
358        float e = fast::exp(data[i] - max_val);
359        data[i] = e;
360        local_sum += e;
361    }
362    shared_val[tid] = local_sum;
363    threadgroup_barrier(mem_flags::mem_threadgroup);
364
365    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366        if (tid < stride) {
367            shared_val[tid] += shared_val[tid + stride];
368        }
369        threadgroup_barrier(mem_flags::mem_threadgroup);
370    }
371    float inv_sum = 1.0f / shared_val[0];
372    threadgroup_barrier(mem_flags::mem_threadgroup);
373
374    // Pass 3: normalize
375    for (uint i = tid; i < n; i += tg_size) {
376        data[i] *= inv_sum;
377    }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384    device const float* gate [[buffer(0)]],
385    device const float* up   [[buffer(1)]],
386    device float* output     [[buffer(2)]],
387    constant uint& n         [[buffer(3)]],
388    uint id [[thread_position_in_grid]])
389{
390    if (id >= n) return;
391    float g = gate[id];
392    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397//   gate = gate_up[0..n], up = gate_up[n..2*n]
398//   output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400    device const float* gate_up [[buffer(0)]],
401    device float* output        [[buffer(1)]],
402    constant uint& n            [[buffer(2)]],
403    uint id [[thread_position_in_grid]])
404{
405    if (id >= n) return;
406    float g = gate_up[id];
407    float u = gate_up[n + id];
408    output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414    device const float* a  [[buffer(0)]],
415    device const float* b  [[buffer(1)]],
416    device float* output   [[buffer(2)]],
417    constant uint& n       [[buffer(3)]],
418    uint id [[thread_position_in_grid]])
419{
420    if (id >= n) return;
421    output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428    device const float* src [[buffer(0)]],
429    device float* dst       [[buffer(1)]],
430    constant uint& count    [[buffer(2)]],
431    uint id [[thread_position_in_grid]])
432{
433    if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440    device const float* src     [[buffer(0)]],
441    device float* dst           [[buffer(1)]],
442    constant uint& src_offset   [[buffer(2)]],  // in floats
443    constant uint& count        [[buffer(3)]],
444    uint id [[thread_position_in_grid]])
445{
446    if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── add_inplace ─────────────────────────────────────────────────────────
450// In-place residual connection: a[i] += b[i]
451// Avoids a separate blit copy for residual add, reducing encoder overhead.
452kernel void add_inplace(
453    device float* a        [[buffer(0)]],
454    device const float* b  [[buffer(1)]],
455    constant uint& n       [[buffer(2)]],
456    uint id [[thread_position_in_grid]])
457{
458    if (id >= n) return;
459    a[id] += b[id];
460}
461
462// ── matmul_vec_q8 ─────────────────────────────────────────────────────
463// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
464// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
465// Operates directly on quantized weights to halve memory bandwidth vs f32,
466// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
467//
468// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
469// because int8->float conversion doubles register demand.  Fully unrolled
470// inner loop with float4 vector loads from shared memory eliminates loop
471// overhead and enables better instruction scheduling.
472// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
473constant constexpr uint Q8_ROWS_PER_SG = 4;
474constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
475
476// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
477// doubles ALU work, so the same register budget applies).
478constant constexpr uint Q4_ROWS_PER_SG = 4;
479constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
480
481kernel void matmul_vec_q8(
482    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
483    device const float* vector   [[buffer(1)]],  // f32 input
484    device float* output         [[buffer(2)]],
485    constant uint& rows          [[buffer(3)]],
486    constant uint& cols          [[buffer(4)]],  // number of elements per row
487    uint tgid [[threadgroup_position_in_grid]],
488    uint tid [[thread_index_in_threadgroup]],
489    uint simd_lane [[thread_index_in_simdgroup]],
490    uint simd_id [[simdgroup_index_in_threadgroup]])
491{
492    // Load vector into shared memory
493    threadgroup float vec_tile[VEC_TILE_SIZE];
494    for (uint i = tid; i < cols; i += 256) {
495        vec_tile[i] = vector[i];
496    }
497    threadgroup_barrier(mem_flags::mem_threadgroup);
498
499    // Each simdgroup handles 4 consecutive rows (lower register pressure)
500    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
501    if (row_base >= rows) return;
502
503    // Q8_0: each block is 34 bytes for 32 elements
504    uint blocks_per_row = cols / 32;
505    uint row_bytes = blocks_per_row * 34;
506
507    // Pointers to each row's Q8_0 data
508    device const uchar* r0 = matrix + row_base * row_bytes;
509    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
510    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
511    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
512
513    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
514
515    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
516        uint bb = blk * 34;
517        uint vb = blk * 32;
518
519        // Prefetch all 4 scales
520        float sc0 = float(*(device const half*)(r0 + bb));
521        float sc1 = float(*(device const half*)(r1 + bb));
522        float sc2 = float(*(device const half*)(r2 + bb));
523        float sc3 = float(*(device const half*)(r3 + bb));
524
525        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
526        // Q8_0 block layout where the int8 data starts at offset +2 from a
527        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
528        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
529        // reduction in memory transactions. Metal's char16/packed_char16 are
530        // reserved types and packed_*int4 require >=4-byte alignment which
531        // this layout does not provide, so packed_short4 is the widest valid
532        // vectorized load.
533        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
534        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
535        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
536        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
537
538        // Load all 8 float4 vector values for this 32-element block from shared memory
539        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
540        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
541        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
542        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
543        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
544        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
545        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
546        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
547
548        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
549        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
550        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
551            short4 _s = short4(SHORT4); \
552            char2 _a = as_type<char2>(_s.x); \
553            char2 _b = as_type<char2>(_s.y); \
554            char2 _c = as_type<char2>(_s.z); \
555            char2 _d = as_type<char2>(_s.w); \
556            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
557            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
558        }
559
560        float4 f0, f1;
561        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
562
563        // Row 0: 4 short4 loads cover 32 int8 weights
564        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
565        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
566        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
567        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
568
569        // Row 1
570        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
571        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
572        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
573        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
574
575        // Row 2
576        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
577        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
578        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
579        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
580
581        // Row 3
582        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
583        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
584        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
585        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
586
587        #undef Q8_UNPACK8
588
589        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
590    }
591
592    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
593    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
594
595    if (simd_lane == 0) {
596        if (row_base     < rows) output[row_base]     = sum0;
597        if (row_base + 1 < rows) output[row_base + 1] = sum1;
598        if (row_base + 2 < rows) output[row_base + 2] = sum2;
599        if (row_base + 3 < rows) output[row_base + 3] = sum3;
600    }
601}
602
603// ── matmul_vec_q4 ─────────────────────────────────────────────────────
604// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
605// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
606// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
607// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
608//
609// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
610// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
611kernel void matmul_vec_q4(
612    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
613    device const float* vector   [[buffer(1)]],  // f32 input
614    device float* output         [[buffer(2)]],
615    constant uint& rows          [[buffer(3)]],
616    constant uint& cols          [[buffer(4)]],  // number of elements per row
617    uint tgid [[threadgroup_position_in_grid]],
618    uint tid [[thread_index_in_threadgroup]],
619    uint simd_lane [[thread_index_in_simdgroup]],
620    uint simd_id [[simdgroup_index_in_threadgroup]])
621{
622    // Load vector into shared memory
623    threadgroup float vec_tile[VEC_TILE_SIZE];
624    for (uint i = tid; i < cols; i += 256) {
625        vec_tile[i] = vector[i];
626    }
627    threadgroup_barrier(mem_flags::mem_threadgroup);
628
629    // Each simdgroup handles 4 consecutive rows
630    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
631    if (row_base >= rows) return;
632
633    // Q4_0: each block is 18 bytes for 32 elements
634    uint blocks_per_row = cols / 32;
635    uint row_bytes = blocks_per_row * 18;
636
637    // Pointers to each row's Q4_0 data
638    device const uchar* r0 = matrix + row_base * row_bytes;
639    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
640    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
641    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
642
643    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
644
645    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
646        uint bb = blk * 18;
647        uint vb = blk * 32;
648
649        // Prefetch all 4 scales
650        float sc0 = float(*(device const half*)(r0 + bb));
651        float sc1 = float(*(device const half*)(r1 + bb));
652        float sc2 = float(*(device const half*)(r2 + bb));
653        float sc3 = float(*(device const half*)(r3 + bb));
654
655        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
656        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
657        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
658        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
659        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
660
661        // Load 8 float4 vector values for 32 elements from shared memory
662        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
663        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
664        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
665        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
666        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
667        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
668        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
669        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
670        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
671
672        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
673        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
674        float bd0=0, bd1=0, bd2=0, bd3=0;
675        uchar4 b;
676
677        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
678        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
679                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
680                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
681                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
682        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
683                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
684                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
685                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
686        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
687                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
688                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
689                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
690        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
691                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
692                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
693                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
694
695        // Row 1
696        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
697                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
698                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
699                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
700        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
701                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
702                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
703                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
704        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
705                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
706                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
707                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
708        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
709                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
710                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
711                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
712
713        // Row 2
714        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
715                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
716                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
717                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
718        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
719                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
720                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
721                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
722        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
723                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
724                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
725                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
726        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
727                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
728                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
729                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
730
731        // Row 3
732        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
733                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
734                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
735                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
736        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
737                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
738                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
739                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
740        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
741                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
742                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
743                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
744        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
745                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
746                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
747                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
748
749        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
750    }
751
752    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
753    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
754
755    if (simd_lane == 0) {
756        if (row_base     < rows) output[row_base]     = sum0;
757        if (row_base + 1 < rows) output[row_base + 1] = sum1;
758        if (row_base + 2 < rows) output[row_base + 2] = sum2;
759        if (row_base + 3 < rows) output[row_base + 3] = sum3;
760    }
761}
762
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions, one query
// head per threadgroup.
//   Step 1: scores[s] = dot(q_head, k[s]) / sqrt(head_dim). Each simdgroup
//           takes one position per iteration and reduces its 32 lanes with
//           simd_sum.
//   Step 2: numerically stable softmax over scores (subtract the max, exp,
//           normalize) using simd_max/simd_sum plus 8-slot threadgroup arrays
//           for the cross-simdgroup step.
//   Step 3: output = sum over s of scores[s] * v[s].
//
// Dispatch contract (not checked in the kernel): one threadgroup per query
// head, with 256 threads (8 simdgroups of 32 lanes) per threadgroup. The
// strides of 8 and 256 and the 8-entry shared_max/shared_sum arrays all
// depend on that size.
// Grouped-query attention: query head h reads KV head
// h / (num_heads / num_kv_heads), so num_heads is assumed to be a multiple
// of num_kv_heads.
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // tgid is uniform across the threadgroup, so this early return is taken
    // by all threads or none and cannot strand a later barrier.
    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Step 1: Compute attention scores Q·K^T with simdgroup reduction.
    // Scores live in threadgroup memory sized by ATTN_SCORES_SIZE, which is
    // defined elsewhere in the generated shader source. There is no bound
    // check here: a seq_len larger than ATTN_SCORES_SIZE would write past
    // the end of the array. NOTE(review): the trailing comment says the size
    // matches the MAX_SEQ_LEN cap; confirm the host enforces
    // seq_len <= ATTN_SCORES_SIZE.
    threadgroup float scores[ATTN_SCORES_SIZE];  // max seq_len for generation phase (matches MAX_SEQ_LEN cap)

    // Simdgroup simd_id handles positions simd_id, simd_id + 8, ... Its lanes
    // split head_dim with stride 32, simd_sum combines the partial dots, and
    // lane 0 stores the score scaled by 1/sqrt(head_dim).
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax over scores (cooperative)
    // Find max: each thread scans positions tid, tid + 256, ...; simd_max
    // reduces within a simdgroup, then thread 0 folds the 8 simdgroup maxima
    // and publishes the result in shared_max[0].
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exp and sum: each thread overwrites its own positions in place with
    // exp(score - max) and accumulates them. Thread 0 stores the reciprocal
    // of the total so every thread can multiply instead of divide.
    // With seq_len == 0 the total is 0 and inv_sum is infinite, but no score
    // is read afterwards and the output row is written as zeros.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V: output = scores · V
    // Each thread owns dimensions d = tid, tid + 256, ... of head_dim. When
    // head_dim < 256, the surplus threads have no work. The sequence loop is
    // unrolled by 4 (four scalar score reads and four V loads per iteration),
    // and a scalar loop handles the last seq_len % 4 positions.
    // V row s of this KV head starts at s * num_kv_heads * head_dim
    // + kv_head * head_dim.
    uint seq_len4 = seq_len & ~3u;  // largest multiple of 4 <= seq_len
    uint v_stride = num_kv_heads * head_dim;
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * v_cache[s * v_stride + v_base]
                 + sc1 * v_cache[(s+1) * v_stride + v_base]
                 + sc2 * v_cache[(s+2) * v_stride + v_base]
                 + sc3 * v_cache[(s+3) * v_stride + v_base];
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
879
880// ── Batched prefill kernels ────────────────────────────────────────────
// These kernels process all M prompt tokens in a single dispatch. The matmul
// variants apply the same weight matrix to every token vector, turning M
// mat-vec products into one mat-mat product for better GPU utilization
// during prompt prefill.
884
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization. Each token row is scaled by the reciprocal
// root-mean-square of its own elements and then multiplied element-wise by
// the shared weight vector:
//   output[t, i] = input[t, i] * rsqrt(mean_j(input[t, j]^2) + eps) * weight[i]
// Dispatch: one threadgroup per token (num_tokens threadgroups) with 256
// threads each. The 256-element stride and the 8-slot partial-sum array both
// rely on that size, and the lane/simdgroup split of tid assumes 32-wide
// simdgroups.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // Surplus threadgroups have no row to normalize; the test is uniform
    // across the group, so no barrier is left waiting.
    if (tgid >= num_tokens) return;

    const uint row_off = tgid * n;

    // Pass 1: each thread squares and sums a strided slice of the row.
    float partial = 0.0f;
    uint k = tid;
    while (k < n) {
        const float x = input[row_off + k];
        partial += x * x;
        k += 256;
    }

    // Reduce inside the simdgroup, then park one value per simdgroup.
    partial = simd_sum(partial);
    threadgroup float shared[8];
    if ((tid % 32) == 0) {
        shared[tid / 32] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the 8 simdgroup sums and publishes 1/rms in slot 0.
    if (tid == 0) {
        float sq_total = 0.0f;
        for (uint g = 0; g < 8; ++g) {
            sq_total += shared[g];
        }
        shared[0] = fast::rsqrt(sq_total / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Pass 2: apply the normalization and the per-channel weight.
    const float scale = shared[0];
    for (uint j = tid; j < n; j += 256) {
        output[row_off + j] = input[row_off + j] * scale * weight[j];
    }
}
934
// ── rope_batch ─────────────────────────────────────────────────────────
// Rotary Position Embedding applied in place to a batch of token rows, each
// with its own position.
// data layout: [M, num_heads * head_dim], positions: [M]
// Adjacent elements (2i, 2i+1) of every head form one rotation pair. The
// pair is rotated by angle positions[token] * theta^(-2i / head_dim).
// Dispatch: one thread per (token, head, pair), i.e. at least
// num_tokens * num_heads * (head_dim / 2) threads; extra threads exit early.
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_row  = num_heads * pairs_per_head;

    // Bounds check first: when there are no pairs at all, every thread leaves
    // here before any division by pairs_per_row / pairs_per_head.
    if (id >= num_tokens * pairs_per_row) return;

    // Flat thread id -> (token, head, pair).
    const uint tok    = id / pairs_per_row;
    const uint in_row = id % pairs_per_row;
    const uint hd     = in_row / pairs_per_head;
    const uint pair   = in_row % pairs_per_head;

    const uint idx = tok * (num_heads * head_dim) + hd * head_dim + 2 * pair;

    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float ang = float(positions[tok]) * inv_freq;
    const float cs = cos(ang);
    const float sn = sin(ang);

    const float lo = data[idx];
    const float hi = data[idx + 1];
    data[idx]     = lo * cs - hi * sn;
    data[idx + 1] = lo * sn + hi * cs;
}
969
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-multiply for a batch. Each token row of gate_up holds the n
// gate values followed by the n up values ([M, 2*n]):
//   output[token*n + i] = silu(gate_up[token*2*n + i]) * gate_up[token*2*n + n + i]
// with silu(g) = g / (1 + exp(-g)). One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint row = id / n;
    const uint col = id % n;

    // Gate and up halves of the same row, n elements apart.
    const uint gate_idx = row * 2 * n + col;
    const float gate = gate_up[gate_idx];
    const float up   = gate_up[gate_idx + n];

    // row * n + col == id, so the output is addressed directly by id.
    output[id] = (gate / (1.0f + fast::exp(-gate))) * up;
}
989
// ── add_inplace_batch ──────────────────────────────────────────────────
// In-place residual connection for a batch: a[i] += b[i] for all M*n elements.
// The batch is treated as one flat array of `total` floats with one thread
// per element; threads with id >= total do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1001
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M embedding rows into a contiguous batch buffer:
//   output[t * dim + d] = embed[tokens[t] * dim + d]
// One thread per output element. Token ids are not range-checked against
// the vocabulary size, so the caller must pass valid ids.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * dim) {
        const uint slot = id / dim;   // entry of tokens[] this thread serves
        const uint col  = id % dim;   // column within that embedding row
        const uint src  = tokens[slot] * dim + col;
        output[id] = embed[src];
    }
}
1019
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair: the 1-D tgid is split
// into token = tgid / row_tgs and tg_in_token = tgid % row_tgs.
//
// Assumptions (not checked in the kernel):
//   - 256 threads per threadgroup (the input-copy stride).
//   - cols <= VEC_TILE_SIZE, since the whole input row is staged in vec_tile.
//   - Each simdgroup computes the 8 rows row_base..row_base+7.
//     NOTE(review): this presumably requires ROWS_PER_SIMDGROUP == 8
//     (defined elsewhere); confirm.
//   - The float4 loads from matrix + base + j assume every row starts
//     16-byte aligned, i.e. cols is a multiple of 4. TODO confirm with the
//     host-side shapes.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so this return cannot strand the barrier below.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so simdgroups past the last row may leave early.
    // After this check row_base itself is always valid; rows +1..+7 are
    // guarded individually below.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Vectorized part: cols rounded down to a multiple of 128 (32 lanes,
    // float4 each). Lane k reads columns 4k..4k+3 of every 128-wide chunk.
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Collapse each per-lane float4 accumulator to a scalar.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail for the last cols % 128 columns, one column per lane.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Reduce each row's lane partials across the simdgroup; lane 0 stores.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1121
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups (1-D). Threadgroup tgid
// handles token tgid / ceil(rows/Q8_ROWS_PER_TG) and one Q8_ROWS_PER_TG-row
// slice of that token's output.
//
// Weight layout per row: cols/32 blocks of 34 bytes. Each block is a half
// scale followed by 32 signed 8-bit quants, and its contribution is
// scale * dot(quants, x). cols is assumed to be a multiple of 32 (any
// remainder is ignored) and to fit in VEC_TILE_SIZE (not checked here).
//
// Each simdgroup produces 4 consecutive output rows, with its 32 lanes
// striding over the blocks of those rows and simd_sum combining the lanes.
// NOTE(review): this assumes Q8_ROWS_PER_SG == 4, matching the four row
// pointers below. The threadgroup size is assumed to be 256 (the
// input-copy stride).
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so this return cannot strand the barrier below.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Only row_base is known to be < rows here. The scale and quant loads in
    // the block loop are unconditional, so when rows is not a multiple of 4
    // the trailing row pointers are clamped to the last valid row. That keeps
    // every read inside the weight buffer; results computed for rows >= rows
    // are discarded by the guarded stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;
        uint vb = blk * 32;

        // Per-block scale (first 2 bytes of the block) for each of the 4 rows.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Splits one packed_short4 (8 signed bytes) into two float4 values:
        // bytes 0-3 go to OUT_LO and bytes 4-7 to OUT_HI.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply the per-block scale once to the whole 32-element dot product.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1237
1238// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
1239// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
1240// Each threadgroup covers 32 rows and TOKENS_PER_TG consecutive tokens, so
1241// the Q8_0 weight blocks are fetched once from device memory and reused for
1242// every token in the tile (1/TOKENS_PER_TG the weight bandwidth of the
1243// per-token dispatch).
1244//
1245// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG)) threadgroups.
1246// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
1247// with simd_sum.  Token vectors are read directly from device memory inside
1248// the block loop (not cached in shared memory) so intermediate_size up to
1249// 8192 fits without spilling threadgroup memory.
// Token-tile size for matmul_q8_gemm_batch: each threadgroup covers this many
// consecutive tokens (tok_base = tgid.y * TOKENS_PER_TG_Q8). The kernel keeps
// one hand-unrolled accumulator set per token (s0x through s3x), so changing
// this value also requires changing the kernel body.
// NOTE(review): the host-side grid height presumably uses the same value;
// confirm.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1251
1252kernel void matmul_q8_gemm_batch(
1253    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1254    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1255    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1256    constant uint& num_tokens    [[buffer(3)]],  // M
1257    constant uint& rows          [[buffer(4)]],
1258    constant uint& cols          [[buffer(5)]],
1259    uint2 tgid [[threadgroup_position_in_grid]],
1260    uint tid [[thread_index_in_threadgroup]],
1261    uint simd_lane [[thread_index_in_simdgroup]],
1262    uint simd_id [[simdgroup_index_in_threadgroup]])
1263{
1264    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1265    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1266    if (row_base >= rows || tok_base >= num_tokens) return;
1267
1268    // How many tokens in this tile are valid?
1269    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1270
1271    uint blocks_per_row = cols / 32;
1272    uint row_bytes = blocks_per_row * 34;
1273
1274    device const uchar* r0 = matrix + row_base * row_bytes;
1275    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1276    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1277    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1278
1279    // Accumulators: 4 tokens × 4 rows per simdgroup.
1280    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1281    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1282    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1283    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1284
1285    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1286        uint bb = blk * 34;
1287        uint vb = blk * 32;
1288
1289        // ── Load weight data ONCE per block (reused across all tokens) ──
1290        float sc0 = float(*(device const half*)(r0 + bb));
1291        float sc1 = float(*(device const half*)(r1 + bb));
1292        float sc2 = float(*(device const half*)(r2 + bb));
1293        float sc3 = float(*(device const half*)(r3 + bb));
1294
1295        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1296        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1297        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1298        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1299
1300        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1301            short4 _s = short4(SHORT4); \
1302            char2 _a = as_type<char2>(_s.x); \
1303            char2 _b = as_type<char2>(_s.y); \
1304            char2 _c = as_type<char2>(_s.z); \
1305            char2 _d = as_type<char2>(_s.w); \
1306            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1307            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1308        }
1309
1310        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1311        // registers for the duration of the block and are dotted against
1312        // every token's vector tile.
1313        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1314        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1315        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1316        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1317
1318        Q8_UNPACK8(d0[0], w0_0, w0_1);
1319        Q8_UNPACK8(d0[1], w0_2, w0_3);
1320        Q8_UNPACK8(d0[2], w0_4, w0_5);
1321        Q8_UNPACK8(d0[3], w0_6, w0_7);
1322
1323        Q8_UNPACK8(d1[0], w1_0, w1_1);
1324        Q8_UNPACK8(d1[1], w1_2, w1_3);
1325        Q8_UNPACK8(d1[2], w1_4, w1_5);
1326        Q8_UNPACK8(d1[3], w1_6, w1_7);
1327
1328        Q8_UNPACK8(d2[0], w2_0, w2_1);
1329        Q8_UNPACK8(d2[1], w2_2, w2_3);
1330        Q8_UNPACK8(d2[2], w2_4, w2_5);
1331        Q8_UNPACK8(d2[3], w2_6, w2_7);
1332
1333        Q8_UNPACK8(d3[0], w3_0, w3_1);
1334        Q8_UNPACK8(d3[1], w3_2, w3_3);
1335        Q8_UNPACK8(d3[2], w3_4, w3_5);
1336        Q8_UNPACK8(d3[3], w3_6, w3_7);
1337
1338        #undef Q8_UNPACK8
1339
1340        // ── For each token, read vector and accumulate against shared weights ──
1341        // Token 0 (always valid: tok_count >= 1).
1342        {
1343            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1344            float4 v0 = *(device const float4*)(a0);
1345            float4 v1 = *(device const float4*)(a0 + 4);
1346            float4 v2 = *(device const float4*)(a0 + 8);
1347            float4 v3 = *(device const float4*)(a0 + 12);
1348            float4 v4 = *(device const float4*)(a0 + 16);
1349            float4 v5 = *(device const float4*)(a0 + 20);
1350            float4 v6 = *(device const float4*)(a0 + 24);
1351            float4 v7 = *(device const float4*)(a0 + 28);
1352            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1353                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1354            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1355                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1356            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1357                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1358            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1359                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1360            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1361        }
1362        // Token 1
1363        if (tok_count > 1) {
1364            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1365            float4 v0 = *(device const float4*)(a1);
1366            float4 v1 = *(device const float4*)(a1 + 4);
1367            float4 v2 = *(device const float4*)(a1 + 8);
1368            float4 v3 = *(device const float4*)(a1 + 12);
1369            float4 v4 = *(device const float4*)(a1 + 16);
1370            float4 v5 = *(device const float4*)(a1 + 20);
1371            float4 v6 = *(device const float4*)(a1 + 24);
1372            float4 v7 = *(device const float4*)(a1 + 28);
1373            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1374                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1375            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1376                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1377            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1378                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1379            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1380                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1381            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1382        }
1383        // Token 2
1384        if (tok_count > 2) {
1385            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1386            float4 v0 = *(device const float4*)(a2);
1387            float4 v1 = *(device const float4*)(a2 + 4);
1388            float4 v2 = *(device const float4*)(a2 + 8);
1389            float4 v3 = *(device const float4*)(a2 + 12);
1390            float4 v4 = *(device const float4*)(a2 + 16);
1391            float4 v5 = *(device const float4*)(a2 + 20);
1392            float4 v6 = *(device const float4*)(a2 + 24);
1393            float4 v7 = *(device const float4*)(a2 + 28);
1394            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1395                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1396            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1397                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1398            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1399                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1400            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1401                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1402            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1403        }
1404        // Token 3
1405        if (tok_count > 3) {
1406            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1407            float4 v0 = *(device const float4*)(a3);
1408            float4 v1 = *(device const float4*)(a3 + 4);
1409            float4 v2 = *(device const float4*)(a3 + 8);
1410            float4 v3 = *(device const float4*)(a3 + 12);
1411            float4 v4 = *(device const float4*)(a3 + 16);
1412            float4 v5 = *(device const float4*)(a3 + 20);
1413            float4 v6 = *(device const float4*)(a3 + 24);
1414            float4 v7 = *(device const float4*)(a3 + 28);
1415            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1416                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1417            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1418                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1419            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1420                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1421            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1422                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1423            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1424        }
1425    }
1426
1427    // simdgroup reduction
1428    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1429    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1430    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1431    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1432
1433    if (simd_lane == 0) {
1434        device float* o0 = outputs + (tok_base + 0) * rows;
1435        if (row_base     < rows) o0[row_base]     = s00;
1436        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1437        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1438        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1439
1440        if (tok_count > 1) {
1441            device float* o1 = outputs + (tok_base + 1) * rows;
1442            if (row_base     < rows) o1[row_base]     = s10;
1443            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1444            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1445            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1446        }
1447        if (tok_count > 2) {
1448            device float* o2 = outputs + (tok_base + 2) * rows;
1449            if (row_base     < rows) o2[row_base]     = s20;
1450            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1451            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1452            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1453        }
1454        if (tok_count > 3) {
1455            device float* o3 = outputs + (tok_base + 3) * rows;
1456            if (row_base     < rows) o3[row_base]     = s30;
1457            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1458            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1459            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1460        }
1461    }
1462}
1463
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
//
// Computes outputs[t, r] = sum_k inputs[t, k] * W[r, k] for every token t and
// weight row r, i.e. outputs = inputs * W^T, where W is the dequantized Q8_0
// matrix.  Each Q8_0 block is 34 bytes: a half scale followed by 32 int8
// quants, so one weight row occupies (cols / 32) * 34 bytes.
//
// Grid: tgid.x selects a 16-row weight tile, tgid.y a 16-token tile.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups per TG, each computing a single 8×8 output sub-tile via one
// simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes are cooperatively
// dequantized into threadgroup memory once per block and reused by all
// simdgroups in the tile.  The cooperative loads index by `tid * 4` over 512
// elements, so the threadgroup must be exactly 128 threads.
//
// Assumptions (verified in the dispatch helper, falls back otherwise):
//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
//   * num_tokens may be any value; a partial token tile at the end of the
//     token range is zero-padded on load and written back via a scratch copy.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and no
    // thread skips a threadgroup_barrier that the others reach.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8;  // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        // Block layout: half scale at rp[0..2], int8 quants at rp[2..34].
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Tokens past num_tokens are zero-filled so they contribute nothing.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Both tiles must be fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B,
                w_tile + sg_row_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // All simdgroups must finish reading the tiles before the next
        // iteration overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage in per-simdgroup scratch and scalar-copy
        // only the valid token rows.  Each simdgroup owns its own 64-float
        // slice, so a simdgroup-level barrier is enough before lane 0 reads it.
        threadgroup float scratch[4 * 64];
        simdgroup_store(C, scratch + simd_id * 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + simd_id * 64 + m * 8;
                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
            }
        }
    }
}
1592
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
// Same math and Q8_0 block layout as matmul_q8_mma: outputs = inputs * W^T.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* adjacent 8×8 accumulators along the row axis:
//
//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//     output sub-tiles (tok, row):
//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration (3 fragment loads for 2 MMAs, versus 2 loads for 1 MMA in the
// 16×16 single-accumulator kernel).  Because each threadgroup covers 4× the
// output area of the 16×16 tile, it launches roughly a quarter as many
// threadgroups.
//
// The cooperative loads index by `tid * 4` over 1024 elements, so the
// threadgroup must be exactly 256 threads.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total
    // (the partial-store scratch below declares another 4 KB).
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx  = simd_id / 2;      // 0..3
    uint sg_row_half = simd_id % 2;      // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        // Block layout: half scale at rp[0..2], int8 quants at rp[2..34].
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load; tokens past num_tokens are zero-filled.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Tiles fully written before any simdgroup loads fragments from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Tiles fully consumed before the next iteration overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Each simdgroup stages into its own 128-float slice, so a
        // simdgroup-level barrier is enough before lane 0 copies it out.
        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1733
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory, halving the tile footprint (4 KB of tiles vs 8 KB).  Both kernels
// additionally declare the same 4 KB float scratch for partial-token stores.
// The intent is higher concurrent-threadgroup occupancy per GPU core on Apple
// Silicon.  NOTE(review): the occupancy gain is not established here; measure.
//
// Numerics:
//   * Weights: each value is int8 × (half block scale), rounded to half.  The
//     exact product can need more significant bits than half's 11, so the
//     dequantized weight is rounded; its magnitude (<= 128 × |scale|) stays
//     in half range for ordinary scales.
//   * Activations: narrowed f32 → f16, i.e. rounded to 11 significant bits,
//     and any |x| > 65504 becomes ±inf.
//   * Accumulators stay `float` (simdgroup_matrix<float>), which limits error
//     growth across K but does not undo either rounding step above.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-adjacent 8×8
// accumulators each; threadgroup must be exactly 256 threads.  Primary win
// vs mma32 is meant to be occupancy at moderate prefill lengths where the
// GPU is wave-starved.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx  = simd_id / 2;
    uint sg_row_half = simd_id % 2;
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16 (product computed in
        // float, then rounded once to half).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); tokens past
        // num_tokens are zero-filled.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles fully written before any simdgroup loads fragments from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Half fragments, float accumulators: C += A(half) * B(half).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Tiles fully consumed before the next iteration overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Row dimension is tile-aligned (dispatch precondition); only the last
    // token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Per-simdgroup 128-float slice; simdgroup-level barrier suffices.
        threadgroup float scratch[8 * 2 * 64];
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1862
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-adjacent accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's.  That is 4 MMAs per 4 fragment loads versus 2 MMAs per 3 loads
// in the 2-accumulator kernel (1.5× the MMA-per-load ratio and 2× the MMA
// work per simdgroup).  It also halves the thread count per threadgroup
// (128 threads), which can improve occupancy on Apple GPUs when the
// concurrent-thread budget is a tighter limit than shared-memory size.
//
// Same math and Q8_0 block layout as matmul_q8_mma32 (outputs = inputs * W^T)
// and same FP16 tile numerics as matmul_q8_mma32_h: half fragments, float
// accumulators.
//
// Preconditions:
//   * threadgroup is exactly 128 threads (loads index `tid * 8` over 1024
//     elements);
//   * cols % 32 == 0 and rows % 32 == 0 — weight loads and output stores are
//     not row-bounds-checked (presumably enforced by the dispatch helper as
//     for the other MMA kernels; confirm);
//   * num_tokens may be anything; the final token tile may be partial.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        // Block layout: half scale at rp[0..2], int8 quants at rp[2..34].
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16); tokens past num_tokens
        // are zero-filled so they contribute nothing.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles fully written before any simdgroup loads fragments from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Tiles fully consumed before the next iteration overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  Full-tile fast path assumes full 16×16 valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else if (out_tok_top < num_tokens) {
        // Partial-token fallback via per-simdgroup scratch.  Quadrants that
        // lie entirely past num_tokens skip this branch (nothing to write),
        // matching the `else if (out_tok < num_tokens)` guard of the sibling
        // MMA kernels.  Each simdgroup owns its own 256-float slice, so a
        // simdgroup-level barrier is enough before lane 0 copies it out.
        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base +   0, 8);
        simdgroup_store(C_01, scratch + sg_base +  64, 8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch + sg_base +   0 + m * 8;
                    threadgroup const float* src1 = scratch + sg_base +  64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2017
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4:
//     outputs[t, r] = sum_k inputs[t, k] * dequant(matrix)[r, k]
// Dequantized weights, token activations and the simdgroup_matrix
// accumulators are all `half`; only the final store widens to f32.
//
// Layout (as indexed below):
//   matrix  - `rows` Q8_0 rows of cols/32 blocks, 34 bytes per block
//             (half scale followed by 32 int8 quants).
//   inputs  - [num_tokens, cols] f32, row-major.
//   outputs - [num_tokens, rows] f32, row-major.
// Each threadgroup computes one 32-token x 32-row output tile with 4
// simdgroups in a 2x2 arrangement, each owning 16 tokens x 16 rows as four
// 8x8 accumulators.  The staging loop assumes 128 threads and 32x32 tiles
// (128 threads x 8 elements = 32 x 32); cols must be a multiple of 32.
// Partial tiles are safe: weight rows past `rows` and tokens past
// `num_tokens` are zero-filled, and out-of-range outputs are never written.
//
// Motivation: FP16 simdgroup MMA may have higher throughput than FP32 on some
// Apple GPUs, and half tiles halve the threadgroup/register footprint.  This
// is hardware-dependent, so benchmark before preferring this kernel over the
// f32-accumulating variant.
//
// Numerical notes: dequantized weights are rounded to half (11-bit
// significand), inputs are narrowed to half (|x| > 65504 becomes inf), and
// the K reduction (e.g. 2048-wide for 1B, 8192-wide for the FFN) accumulates
// in half, so rounding error grows with K.  Validate per model (e.g. 135M /
// 1B / 3B) before enabling.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and no
    // thread is left waiting at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    threadgroup half w_tile[MMA32_ROW_TILE * 32];  // [row][k]   dequantized weights
    threadgroup half t_tile[MMA32_TOK_TILE * 32];  // [token][k] activations

    // 2x2 simdgroup arrangement: each simdgroup owns 16 tokens x 16 rows.
    uint sg_tok_base = (simd_id / 2) * 16;
    uint sg_row_base = (simd_id % 2) * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Stage one 32-wide K slice of both operands.  Thread `tid` fills
        // entries tid*8 .. tid*8+7 of each tile: tile line idx / 32 (weight
        // row or token), column idx % 32.
        for (uint ii = 0; ii < 8; ii++) {
            uint idx = tid * 8 + ii;
            uint line = idx / 32;
            uint k = idx % 32;

            // Zero-fill weight rows past the end of the matrix so a partial
            // row tile never reads beyond the weight buffer.
            uint row = row_base + line;
            half w = half(0);
            if (row < rows) {
                device const uchar* rp = matrix + row * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w = half(float(ival) * sc);
            }
            w_tile[line * 32 + k] = w;

            uint tok = tok_base + line;
            t_tile[line * 32 + k] = (tok < num_tokens)
                ? half(inputs[tok * cols + blk * 32 + k])
                : half(0);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four 8-wide K sub-steps.  B is loaded transposed from the [row][k]
        // weight tile so that C[token][row] = sum_k A[token][k] * W[row][k].
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Tiles are overwritten by the next K slice.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Widen the half accumulators to f32 through threadgroup scratch.  Each
    // simdgroup owns 256 halves: four row-major 8x8 tiles in the order
    // C_00, C_01, C_10, C_11, with element [m][j] at m*8 + j.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    // Each simdgroup reads back only its own region, so a simdgroup barrier
    // is sufficient.
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // All 32 lanes share the write-back: lane L handles scratch elements
    // L, L+32, ..., L+224.  Out-of-range tokens/rows are skipped.
    uint lane = tid % 32;
    for (uint e = lane; e < 256; e += 32) {
        uint tile = e / 64;        // 0: C_00, 1: C_01, 2: C_10, 3: C_11
        uint m = (e % 64) / 8;     // token offset inside the 8x8 tile
        uint j = e % 8;            // row offset inside the 8x8 tile
        uint tok = tok_base + sg_tok_base + (tile / 2) * 8 + m;
        uint row = row_base + sg_row_base + (tile % 2) * 8 + j;
        if (tok < num_tokens && row < rows) {
            outputs[tok * rows + row] = float(scratch[sg_base + e]);
        }
    }
}
2155
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add of a per-row bias over a row-major
// [num_tokens, rows] buffer (used for the Qwen2 QKV bias after the fused
// qkv matmul):
//     out[token, i] += bias[i]    for i in 0..rows, token in 0..num_tokens
// One thread per element; threads beyond num_tokens * rows do nothing.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows]
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * rows) {
        out[id] += bias[id % rows];
    }
}
2172
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors:
//     outputs[t, r] = sum_k inputs[t, k] * dequant(matrix)[r, k]
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups (1D), where
// tgid = token * ceil(rows/Q4_ROWS_PER_TG) + row_tile.
// Each simdgroup produces 4 consecutive rows (assumes Q4_ROWS_PER_SG == 4).
// Lanes stride over the 18-byte Q4_0 blocks (half scale + 16 packed bytes)
// and the partial sums are combined with simd_sum.
// Assumes 256 threads per threadgroup, cols a multiple of 32, and
// cols <= VEC_TILE_SIZE.

// Unscaled dot product of one Q4_0 block with its 32 activations.
// Byte q of the block holds element q in its low nibble and element q + 16 in
// its high nibble, each biased by +8.  v[0..3] hold activations 0..15 and
// v[4..7] hold activations 16..31.
inline float q4_0_batch_block_dot(device const packed_uchar4* qs,
                                  thread const float4* v)
{
    float acc = 0.0f;
    for (uint i = 0; i < 4; i++) {
        uchar4 b = qs[i];
        float4 lo = float4(int4(b & uchar4(0xF)) - 8);
        float4 hi = float4(int4(b >> uchar4(4)) - 8);
        acc += dot(lo, v[i]) + dot(hi, v[i + 4]);
    }
    return acc;
}

kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so returning before the barrier is safe.
    if (token >= num_tokens) return;

    // Stage this token's activations in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Per-simdgroup exit; no barriers follow.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // In a trailing partial group, rows past the end are clamped to the last
    // valid row so `matrix` is never read out of bounds; their sums are
    // discarded by the guarded store below.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;

        // The block's 32 activations, loaded once and shared by all 4 rows.
        float4 v[8];
        threadgroup const float4* vt = (threadgroup const float4*)(vec_tile + blk * 32);
        for (uint i = 0; i < 8; i++) {
            v[i] = vt[i];
        }

        sum0 += float(*(device const half*)(r0 + bb))
              * q4_0_batch_block_dot((device const packed_uchar4*)(r0 + bb + 2), v);
        sum1 += float(*(device const half*)(r1 + bb))
              * q4_0_batch_block_dot((device const packed_uchar4*)(r1 + bb + 2), v);
        sum2 += float(*(device const half*)(r2 + bb))
              * q4_0_batch_block_dot((device const packed_uchar4*)(r2 + bb + 2), v);
        sum3 += float(*(device const half*)(r3 + bb))
              * q4_0_batch_block_dot((device const packed_uchar4*)(r3 + bb + 2), v);
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2273
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter K or V for a batch of M tokens from the strided batch QKV buffer
// into the KV cache.
//   src: [M, src_stride] floats; the vector starts at src_offset in each row.
//   dst: contiguous [max_seq, kv_dim] cache; token t lands in row base_pos + t.
// One thread per copied float (M * kv_dim threads); extra threads do nothing.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer
    device float* dst        [[buffer(1)]],  // KV cache
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    uint count = M * kv_dim;
    if (id < count) {
        uint tok = id / kv_dim;
        uint col = id - tok * kv_dim;  // == id % kv_dim
        dst[(base_pos + tok) * kv_dim + col] = src[tok * src_stride + src_offset + col];
    }
}
2296
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair; grid = M * num_heads
// threadgroups (1D).  The reductions below assume 256 threads per
// threadgroup (8 simdgroups of 32).
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i attends to positions 0..base_pos+i inclusive.
// Grouped-query attention: each run of num_heads / num_kv_heads consecutive
// query heads shares one KV head.
//
// All scores are materialized in threadgroup memory, so seq_len must not
// exceed ATTN_SCORES_SIZE.  There is no in-kernel guard; longer sequences
// must be routed to attention_flash_batch, which streams K/V in tiles.
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // Uniform per threadgroup, so returning before the barriers is safe.
    if (token_idx >= M) return;

    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Scores for every attended position.  Capacity is ATTN_SCORES_SIZE
    // (defined elsewhere); see the header note about seq_len.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V.  Each active thread produces 4 output dims using
    // float4 loads (head_dim=64 -> 16 active threads).
    uint v_stride = num_kv_heads * head_dim;
    // float4 accesses need 16-byte-aligned addresses.  Assuming 16-byte-aligned
    // buffer bases, every offset below is a multiple of 4 floats only when
    // head_dim is.  Otherwise skip the vector path and let the scalar loop
    // handle every dimension.
    uint head_dim4 = (head_dim % 4 == 0) ? head_dim / 4 : 0;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Scalar path: the tail beyond head_dim4 * 4, or all dims when head_dim
    // is not a multiple of 4.
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2422
2423// ── attention_flash_batch ─────────────────────────────────────────────
2424// Streaming attention with online softmax.  Same grid as attention_batch
2425// (M × num_heads threadgroups, one per (token, head) pair) but the scores
2426// matrix is never materialized.  K/V positions are processed in a tile of
2427// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
2428// the standard flash-attention recurrence:
2429//
2430//     m_new   = max(m_old, tile_max)
2431//     alpha   = exp(m_old - m_new)
2432//     l_new   = alpha * l_old + sum(exp(S - m_new))
2433//     O_new   = alpha * O_old + sum(exp(S - m_new) * V)
2434//     O_final = O / l
2435//
2436// Unlike attention_batch, whose threadgroup `scores` array bounds seq_len by
2437// ATTN_SCORES_SIZE with no in-kernel guard, threadgroup memory use here is
2438// O(head_dim + FLASH_K_TILE), independent of seq_len.
2439//
2440// Assumptions: head_dim ≤ FLASH_MAX_HEAD_DIM = 256, not checked in-kernel
// (Llama/Qwen/Mistral/Phi-3 all satisfy this); 256 threads per threadgroup
// (8 simdgroups of 32), which the `+= 256` loops and the 8-entry sg_scratch
// reductions rely on.
2441constant constexpr uint FLASH_K_TILE = 32;        // K/V positions per tile
2442constant constexpr uint FLASH_MAX_HEAD_DIM = 256; // capacity of q_sh / o_sh
2443
2444kernel void attention_flash_batch(
2445    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
2446    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
2447    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
2448    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
2449    constant uint& M                 [[buffer(4)]],
2450    constant uint& base_pos          [[buffer(5)]],
2451    constant uint& num_heads         [[buffer(6)]],
2452    constant uint& num_kv_heads      [[buffer(7)]],
2453    constant uint& head_dim          [[buffer(8)]],
2454    constant uint& q_stride          [[buffer(9)]],
2455    uint tgid [[threadgroup_position_in_grid]],
2456    uint tid [[thread_index_in_threadgroup]],
2457    uint simd_lane [[thread_index_in_simdgroup]],
2458    uint simd_id [[simdgroup_index_in_threadgroup]])
2459{
2460    uint token_idx = tgid / num_heads;
2461    uint head = tgid % num_heads;
    // Depends only on tgid, so the whole threadgroup exits together and no
    // thread is left waiting at a barrier.
2462    if (token_idx >= M) return;
2463
    // Grouped-query attention: each run of num_heads / num_kv_heads
    // consecutive query heads shares one KV head.
2464    uint kv_head = head / (num_heads / num_kv_heads);
2465    uint seq_len = base_pos + token_idx + 1;  // causal: attend to [0, base_pos + token_idx]
2466    uint q_off = token_idx * q_stride + head * head_dim;
2467    uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2468
2469    // Threadgroup state:
2470    //   q_sh:      Q vector for this (token, head), loaded once
2471    //   o_sh:      running output vector, updated each K tile
2472    //   scores_sh: scores for the current K tile only
2473    //   stats:     [running max, running sum]  (see flash-attention recurrence)
2474    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
2475    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
2476    threadgroup float scores_sh[FLASH_K_TILE];
2477    threadgroup float stats[2];
2478    threadgroup float sg_scratch[8];  // simdgroup-level reduction buffer
2479
2480    // --- Load Q (one row) and zero the running O ---
2481    for (uint d = tid; d < head_dim; d += 256) {
2482        q_sh[d] = q_batch[q_off + d];
2483        o_sh[d] = 0.0f;
2484    }
2485    if (tid == 0) {
2486        stats[0] = -INFINITY;
2487        stats[1] = 0.0f;
2488    }
2489    threadgroup_barrier(mem_flags::mem_threadgroup);
2490
2491    float scale = fast::rsqrt(float(head_dim));
    // K and V share the [max_seq, num_kv_heads * head_dim] layout, so one
    // stride/base pair addresses both caches.
2492    uint v_stride = num_kv_heads * head_dim;
2493    uint v_base = kv_head * head_dim;
2494
2495    // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
2496    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
2497        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);
2498
2499        // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
2500        // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
2501        for (uint ti = simd_id; ti < tile_n; ti += 8) {
2502            uint k_off = (kv_base + ti) * v_stride + v_base;  // same layout as V stride
2503            float dot = 0.0f;
2504            for (uint d = simd_lane; d < head_dim; d += 32) {
2505                dot += q_sh[d] * k_cache[k_off + d];
2506            }
2507            dot = simd_sum(dot);
2508            if (simd_lane == 0) {
2509                scores_sh[ti] = dot * scale;
2510            }
2511        }
2512        threadgroup_barrier(mem_flags::mem_threadgroup);
2513
2514        // [2] Tile max via cooperative reduction.
        // Only tid < tile_n touch scores; simdgroups with no entries publish
        // -INFINITY, which cannot win the max.
2515        float local_max = -INFINITY;
2516        for (uint s = tid; s < tile_n; s += 256) {
2517            local_max = max(local_max, scores_sh[s]);
2518        }
2519        local_max = simd_max(local_max);
2520        if (simd_lane == 0) {
2521            sg_scratch[simd_id] = local_max;
2522        }
2523        threadgroup_barrier(mem_flags::mem_threadgroup);
2524        // [3] Merge with running max, compute alpha, rescale running l.
        // Only tid 0 computes m_new/alpha; every other thread picks them up
        // from sg_scratch[0..1] after the barrier below.
2525        float m_new;
2526        float alpha;
2527        if (tid == 0) {
2528            float tile_max = sg_scratch[0];
2529            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
2530            float m_old = stats[0];
2531            m_new = max(m_old, tile_max);
2532            // First iteration: m_old = -inf → alpha = 0 (reset O).
2533            alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2534            stats[0] = m_new;
2535            stats[1] *= alpha;
2536            // Broadcast via sg_scratch.
2537            sg_scratch[0] = alpha;
2538            sg_scratch[1] = m_new;
2539        }
2540        threadgroup_barrier(mem_flags::mem_threadgroup);
2541        alpha = sg_scratch[0];
2542        m_new = sg_scratch[1];
2543
2544        // [4] Rescale running output by alpha, then compute exp(scores - m_new).
2545        for (uint d = tid; d < head_dim; d += 256) {
2546            o_sh[d] *= alpha;
2547        }
2548        for (uint s = tid; s < tile_n; s += 256) {
2549            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
2550        }
2551        threadgroup_barrier(mem_flags::mem_threadgroup);
2552
2553        // [5] Tile sum → update running l.
        // Reusing sg_scratch here is race-free: every thread read alpha and
        // m_new out of it before the barrier that ends step [4].
2554        float local_sum = 0.0f;
2555        for (uint s = tid; s < tile_n; s += 256) {
2556            local_sum += scores_sh[s];
2557        }
2558        local_sum = simd_sum(local_sum);
2559        if (simd_lane == 0) {
2560            sg_scratch[simd_id] = local_sum;
2561        }
2562        threadgroup_barrier(mem_flags::mem_threadgroup);
2563        if (tid == 0) {
2564            float tile_sum = 0.0f;
2565            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
2566            stats[1] += tile_sum;
2567        }
2568        threadgroup_barrier(mem_flags::mem_threadgroup);
2569
2570        // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
2571        for (uint d = tid; d < head_dim; d += 256) {
2572            float acc = 0.0f;
2573            for (uint s = 0; s < tile_n; s++) {
2574                acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
2575            }
2576            o_sh[d] += acc;
2577        }
        // scores_sh is overwritten by the next tile's step [1].
2578        threadgroup_barrier(mem_flags::mem_threadgroup);
2579    }
2580
2581    // --- Normalize and write output ---
    // stats[1] was last written by tid 0 before a barrier inside the loop, so
    // every thread reads the final denominator l here.
2582    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
2583    for (uint d = tid; d < head_dim; d += 256) {
2584        output_batch[out_off + d] = o_sh[d] * inv_l;
2585    }
2586}
2587
2588// ── attention_mma_flash_batch ─────────────────────────────────────────
2589// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
2590// both Q·K^T and P·V.  Processes Q_BLOCK=8 tokens of one head per
2591// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
2592// K/V loads across 8 Q rows and using hardware matrix-multiply for the
2593// arithmetic.
2594//
2595// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
2596// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
2597// to attention_batch / attention_flash_batch otherwise.
2598//
2599// Online softmax recurrence is identical to attention_flash_batch but
2600// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
// Tile sizes for attention_mma_flash_batch.
constant constexpr uint FLASH_MMA_Q_BLOCK      = 8u;   // query rows (tokens) per threadgroup
constant constexpr uint FLASH_MMA_K_BLOCK      = 32u;  // K/V positions staged per tile
constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128u; // head_dim capacity of the threadgroup tiles
2604
2605kernel void attention_mma_flash_batch(
2606    device const float* q_batch      [[buffer(0)]],
2607    device const float* k_cache      [[buffer(1)]],
2608    device const float* v_cache      [[buffer(2)]],
2609    device float* output_batch       [[buffer(3)]],
2610    constant uint& M                 [[buffer(4)]],
2611    constant uint& base_pos          [[buffer(5)]],
2612    constant uint& num_heads         [[buffer(6)]],
2613    constant uint& num_kv_heads      [[buffer(7)]],
2614    constant uint& head_dim          [[buffer(8)]],
2615    constant uint& q_stride          [[buffer(9)]],
2616    uint2 tgid [[threadgroup_position_in_grid]],
2617    uint tid [[thread_index_in_threadgroup]],
2618    uint simd_lane [[thread_index_in_simdgroup]],
2619    uint simd_id [[simdgroup_index_in_threadgroup]])
2620{
2621    uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2622    uint head = tgid.y;
2623    if (q_block_start >= M) return;
2624    uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2625
2626    uint kv_head = head / (num_heads / num_kv_heads);
2627    // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2628    // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2629    uint seq_len = base_pos + q_block_start + q_valid;
2630    float scale = fast::rsqrt(float(head_dim));
2631
2632    uint kv_stride = num_kv_heads * head_dim;
2633    uint kv_base_off = kv_head * head_dim;
2634
2635    // ── Threadgroup memory ──
2636    // q_sh:  [Q_BLOCK, head_dim] half  — Q tile, loaded once
2637    // k_sh:  [K_BLOCK, head_dim] half  — K tile, refreshed per kv_base iter
2638    // v_sh:  [K_BLOCK, head_dim] half  — V tile, refreshed per kv_base iter
2639    // s_sh:  [Q_BLOCK, K_BLOCK] float  — raw Q·K^T scores, then scaled+masked
2640    // p_sh:  [Q_BLOCK, K_BLOCK] half   — softmax probabilities (for P·V MMA)
2641    // o_sh:  [Q_BLOCK, head_dim] float — running output accumulator
2642    // m_sh:  [Q_BLOCK] float           — running max per Q row
2643    // l_sh:  [Q_BLOCK] float           — running softmax denominator per Q row
2644    // scratch: 4*Q_BLOCK floats        — per-row reduction scratch
2645    threadgroup half  q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2646    threadgroup half  k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2647    threadgroup half  v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2648    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2649    threadgroup half  p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2650    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2651    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2652    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2653    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2654
2655    // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2656    uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2657    for (uint i = tid; i < qblock_elems; i += 128) {
2658        uint q = i / head_dim;
2659        uint d = i % head_dim;
2660        if (q < q_valid) {
2661            uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2662            q_sh[q * head_dim + d] = half(q_batch[q_off]);
2663        } else {
2664            q_sh[q * head_dim + d] = half(0);
2665        }
2666        o_sh[q * head_dim + d] = 0.0f;
2667    }
2668    if (tid < FLASH_MMA_Q_BLOCK) {
2669        m_sh[tid] = -INFINITY;
2670        l_sh[tid] = 0.0f;
2671    }
2672    threadgroup_barrier(mem_flags::mem_threadgroup);
2673
2674    // ── Stream K/V in K_BLOCK chunks ──
2675    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2676        uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2677
2678        // Load K and V tile into TG memory (as half).
2679        uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2680        for (uint i = tid; i < kv_tile_elems; i += 128) {
2681            uint k_pos = i / head_dim;
2682            uint d = i % head_dim;
2683            if (k_pos < tile_n) {
2684                uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2685                k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2686                v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2687            } else {
2688                k_sh[k_pos * head_dim + d] = half(0);
2689                v_sh[k_pos * head_dim + d] = half(0);
2690            }
2691        }
2692        threadgroup_barrier(mem_flags::mem_threadgroup);
2693
2694        // ── Phase 1: S = Q @ K^T via MMA ──
2695        // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2696        // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2697        // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2698        {
2699            simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2700            uint dim_chunks = head_dim / 8;
2701            for (uint dc = 0; dc < dim_chunks; dc++) {
2702                simdgroup_matrix<half, 8, 8> A, B;
2703                // A = Q[0:8, dc*8 : dc*8+8]  (rows of Q, no transpose)
2704                simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2705                // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2706                // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2707                // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2708                // transpose=true, which places it in the register as K^T of that sub-block.
2709                simdgroup_load(B,
2710                    k_sh + (simd_id * 8) * head_dim + dc * 8,
2711                    head_dim, ulong2(0, 0), true);
2712                simdgroup_multiply_accumulate(C, A, B, C);
2713            }
2714            // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2715            simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2716        }
2717        threadgroup_barrier(mem_flags::mem_threadgroup);
2718
2719        // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2720        // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2721        uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2722        for (uint i = tid; i < s_elems; i += 128) {
2723            uint q = i / FLASH_MMA_K_BLOCK;
2724            uint k = i % FLASH_MMA_K_BLOCK;
2725            uint global_q = q_block_start + q;
2726            uint global_kv = kv_base + k;
2727            bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2728            s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2729        }
2730        threadgroup_barrier(mem_flags::mem_threadgroup);
2731
2732        // ── Phase 2b: per-row max via simdgroup reduction ──
2733        // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2734        // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2735        {
2736            uint row_base = simd_id * 2;
2737            for (uint qr = 0; qr < 2; qr++) {
2738                uint q = row_base + qr;
2739                float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2740                float row_max = simd_max(my);
2741                if (simd_lane == 0) {
2742                    scratch[q] = row_max;  // tile_max[q]
2743                }
2744            }
2745        }
2746        threadgroup_barrier(mem_flags::mem_threadgroup);
2747
2748        // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2749        if (tid < FLASH_MMA_Q_BLOCK) {
2750            uint q = tid;
2751            float m_old = m_sh[q];
2752            float tile_max = scratch[q];
2753            float m_new = max(m_old, tile_max);
2754            float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2755            m_sh[q] = m_new;
2756            l_sh[q] = l_sh[q] * alpha;
2757            // scratch[q]               = m_new   (for phase 2d)
2758            // scratch[Q_BLOCK + q]     = alpha   (for phase 3)
2759            scratch[q] = m_new;
2760            scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2761        }
2762        threadgroup_barrier(mem_flags::mem_threadgroup);
2763
2764        // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2765        {
2766            uint row_base = simd_id * 2;
2767            for (uint qr = 0; qr < 2; qr++) {
2768                uint q = row_base + qr;
2769                float m_new = scratch[q];
2770                float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2771                p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2772                float row_sum = simd_sum(p);
2773                if (simd_lane == 0) {
2774                    scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2775                }
2776            }
2777        }
2778        threadgroup_barrier(mem_flags::mem_threadgroup);
2779
2780        // ── Phase 2e: l_sh += tile_sum ──
2781        if (tid < FLASH_MMA_Q_BLOCK) {
2782            uint q = tid;
2783            l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2784        }
2785
2786        // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2787        threadgroup_barrier(mem_flags::mem_threadgroup);
2788        for (uint i = tid; i < qblock_elems; i += 128) {
2789            uint q = i / head_dim;
2790            float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2791            o_sh[i] *= alpha;
2792        }
2793        threadgroup_barrier(mem_flags::mem_threadgroup);
2794
2795        // ── Phase 4: O += P @ V via MMA ──
2796        // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2797        // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2798        // For head_dim=64: 16 dims/sg = 2 tiles.  head_dim=128: 32 dims/sg = 4 tiles.
2799        {
2800            uint dims_per_sg = head_dim / 4;       // 16 or 32
2801            uint tiles_per_sg = dims_per_sg / 8;   // 2 or 4
2802            uint sg_d_base = simd_id * dims_per_sg;
2803            for (uint t = 0; t < tiles_per_sg; t++) {
2804                uint d_base = sg_d_base + t * 8;
2805                simdgroup_matrix<float, 8, 8> O_acc;
2806                simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2807                uint k_chunks = FLASH_MMA_K_BLOCK / 8;  // 4
2808                for (uint kc = 0; kc < k_chunks; kc++) {
2809                    simdgroup_matrix<half, 8, 8> A, B;
2810                    simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2811                                   ulong2(0, 0), false);
2812                    simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2813                                   ulong2(0, 0), false);
2814                    simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2815                }
2816                simdgroup_store(O_acc, o_sh + d_base, head_dim);
2817            }
2818        }
2819        threadgroup_barrier(mem_flags::mem_threadgroup);
2820    }
2821
2822    // ── Finalize: O /= l, write to output ──
2823    for (uint i = tid; i < qblock_elems; i += 128) {
2824        uint q = i / head_dim;
2825        uint d = i % head_dim;
2826        if (q < q_valid) {
2827            float l = l_sh[q];
2828            float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2829            uint token_idx = q_block_start + q;
2830            uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2831            output_batch[out_off] = o_sh[i] * inv_l;
2832        }
2833    }
2834}
2835
2836// ── rope_qk_batch ─────────────────────────────────────────────────────
2837// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2838// launch + memory barrier per layer. Both Q and K live in the same
2839// qkv_data buffer at different offsets within each token's row.
2840// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2841kernel void rope_qk_batch(
2842    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
2843    constant uint& M                 [[buffer(1)]],   // num tokens
2844    constant uint& base_pos          [[buffer(2)]],   // starting position
2845    constant uint& num_q_heads       [[buffer(3)]],
2846    constant uint& num_kv_heads      [[buffer(4)]],
2847    constant uint& head_dim          [[buffer(5)]],
2848    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
2849    constant float& theta            [[buffer(7)]],
2850    uint id [[thread_position_in_grid]])
2851{
2852    uint half_dim = head_dim / 2;
2853    uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2854    uint token = id / total_pairs;
2855    uint pair = id % total_pairs;
2856    if (token >= M) return;
2857
2858    uint pos = base_pos + token;
2859    uint q_pairs = num_q_heads * half_dim;
2860
2861    uint h, i, offset;
2862    if (pair < q_pairs) {
2863        // Q head
2864        h = pair / half_dim;
2865        i = pair % half_dim;
2866        offset = token * qkv_stride + h * head_dim + i * 2;
2867    } else {
2868        // K head
2869        uint kp = pair - q_pairs;
2870        h = kp / half_dim;
2871        i = kp % half_dim;
2872        uint k_start = num_q_heads * head_dim;
2873        offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2874    }
2875
2876    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2877    float angle = float(pos) * freq;
2878    float cos_val = cos(angle);
2879    float sin_val = sin(angle);
2880
2881    float x0 = qkv_data[offset];
2882    float x1 = qkv_data[offset + 1];
2883    qkv_data[offset]     = x0 * cos_val - x1 * sin_val;
2884    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2885}
2886
2887// ── copy_kv_both_batch ────────────────────────────────────────────────
2888// Fused K+V cache copy in a single dispatch: copies both K and V from
2889// the strided batch QKV buffer to their respective KV cache buffers.
2890// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2891kernel void copy_kv_both_batch(
2892    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
2893    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
2894    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
2895    constant uint& M           [[buffer(3)]],  // num tokens in batch
2896    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
2897    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
2898    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
2899    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
2900    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
2901    uint id [[thread_position_in_grid]])
2902{
2903    // Total elements = M * kv_dim * 2 (K + V)
2904    uint total_kv = M * kv_dim;
2905    if (id >= total_kv * 2) return;
2906
2907    uint is_v = id / total_kv;        // 0 = K, 1 = V
2908    uint local_id = id % total_kv;
2909    uint token = local_id / kv_dim;
2910    uint d = local_id % kv_dim;
2911
2912    uint dst_off = (base_pos + token) * kv_dim + d;
2913    uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2914
2915    if (is_v) {
2916        v_dst[dst_off] = src[src_off];
2917    } else {
2918        k_dst[dst_off] = src[src_off];
2919    }
2920}
2921"#
2922    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2923    .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2924}
2925
2926// ---------------------------------------------------------------------------
2927// model.rs generation
2928// ---------------------------------------------------------------------------
2929
2930fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2931    let mut code = String::with_capacity(48 * 1024);
2932    emit_model_header(&mut code, config)?;
2933    emit_metal_model_struct(&mut code, config)?;
2934    emit_layer_buffers_struct(&mut code, config)?;
2935    emit_metal_model_impl(&mut code, config)?;
2936    emit_helper_functions(&mut code)?;
2937    Ok(code)
2938}
2939
2940fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2941    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2942    writeln!(
2943        code,
2944        "//! Model: {} ({} layers, hidden={})",
2945        config.architecture, config.num_layers, config.hidden_size
2946    )?;
2947    writeln!(code, "//!")?;
2948    writeln!(
2949        code,
2950        "//! Uses native Metal compute pipelines via the metal crate."
2951    )?;
2952    writeln!(code)?;
2953    writeln!(code, "#![allow(dead_code)]")?;
2954    writeln!(code)?;
2955    writeln!(code, "use metal::*;")?;
2956    writeln!(code, "#[allow(unused_imports)]")?;
2957    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2958    writeln!(code, "use std::mem;")?;
2959    writeln!(code)?;
2960
2961    // Model constants
2962    writeln!(
2963        code,
2964        "// ── Model constants ──────────────────────────────────"
2965    )?;
2966    writeln!(
2967        code,
2968        "pub const HIDDEN_SIZE: usize = {};",
2969        config.hidden_size
2970    )?;
2971    writeln!(
2972        code,
2973        "pub const INTERMEDIATE_SIZE: usize = {};",
2974        config.intermediate_size
2975    )?;
2976    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2977    writeln!(
2978        code,
2979        "pub const NUM_HEADS: usize = {};",
2980        config.num_attention_heads
2981    )?;
2982    writeln!(
2983        code,
2984        "pub const NUM_KV_HEADS: usize = {};",
2985        config.num_kv_heads
2986    )?;
2987    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2988    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2989    let effective_seq_len = config.max_seq_len.min(4096);
2990    writeln!(
2991        code,
2992        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
2993        effective_seq_len, config.max_seq_len
2994    )?;
2995    writeln!(
2996        code,
2997        "pub const RMS_NORM_EPS: f32 = {:e};",
2998        config.rms_norm_eps
2999    )?;
3000    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3001    writeln!(
3002        code,
3003        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3004    )?;
3005    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3006    writeln!(code)?;
3007
3008    Ok(())
3009}
3010
3011fn emit_metal_model_struct(
3012    code: &mut String,
3013    config: &ModelConfig,
3014) -> Result<(), MetalCodegenError> {
3015    writeln!(
3016        code,
3017        "// ── MetalModel ──────────────────────────────────────────"
3018    )?;
3019    writeln!(code)?;
3020    writeln!(
3021        code,
3022        "/// Metal-accelerated transformer model for Apple Silicon."
3023    )?;
3024    writeln!(code, "///")?;
3025    writeln!(
3026        code,
3027        "/// Uses unified memory for zero-copy weight access and native Metal"
3028    )?;
3029    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3030    writeln!(code, "pub struct MetalModel {{")?;
3031    writeln!(code, "    device: Device,")?;
3032    writeln!(code, "    queue: CommandQueue,")?;
3033    writeln!(code)?;
3034    writeln!(code, "    // ── Compute pipelines ──")?;
3035    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
3036    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
3037    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
3038    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
3039    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
3040    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
3041    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
3042    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
3043    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
3044    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
3045    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
3046    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
3047    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
3048    writeln!(code)?;
3049    writeln!(code, "    // ── Batched prefill pipelines ──")?;
3050    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
3051    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
3052    writeln!(
3053        code,
3054        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3055    )?;
3056    writeln!(code, "    matmul_q8_mma_pipeline: ComputePipelineState,")?;
3057    writeln!(code, "    matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3058    writeln!(
3059        code,
3060        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3061    )?;
3062    writeln!(
3063        code,
3064        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3065    )?;
3066    writeln!(
3067        code,
3068        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3069    )?;
3070    if config.qkv_bias {
3071        writeln!(code, "    add_bias_batch_pipeline: ComputePipelineState,")?;
3072    }
3073    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
3074    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
3075    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
3076    writeln!(
3077        code,
3078        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
3079    )?;
3080    writeln!(
3081        code,
3082        "    add_inplace_batch_pipeline: ComputePipelineState,"
3083    )?;
3084    writeln!(
3085        code,
3086        "    copy_embedding_batch_pipeline: ComputePipelineState,"
3087    )?;
3088    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
3089    writeln!(
3090        code,
3091        "    attention_flash_batch_pipeline: ComputePipelineState,"
3092    )?;
3093    writeln!(
3094        code,
3095        "    attention_mma_flash_batch_pipeline: ComputePipelineState,"
3096    )?;
3097    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
3098    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
3099    writeln!(
3100        code,
3101        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
3102    )?;
3103    writeln!(code)?;
3104    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
3105    writeln!(
3106        code,
3107        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3108    )?;
3109    writeln!(code, "    embed_buf: Buffer,")?;
3110    writeln!(code)?;
3111    writeln!(code, "    /// Per-layer weight buffers")?;
3112    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
3113    writeln!(code)?;
3114    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
3115    writeln!(code, "    norm_buf: Buffer,")?;
3116    writeln!(code)?;
3117    writeln!(
3118        code,
3119        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3120    )?;
3121    writeln!(code, "    lm_head_buf: Buffer,")?;
3122    writeln!(code)?;
3123    writeln!(
3124        code,
3125        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
3126    )?;
3127    writeln!(code, "    hidden_buf: Buffer,")?;
3128    writeln!(code, "    residual_buf: Buffer,")?;
3129    writeln!(code, "    normed_buf: Buffer,")?;
3130    writeln!(
3131        code,
3132        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3133    )?;
3134    writeln!(code, "    qkv_buf: Buffer,")?;
3135    writeln!(code, "    attn_out_buf: Buffer,")?;
3136    writeln!(code, "    attn_proj_buf: Buffer,")?;
3137    writeln!(
3138        code,
3139        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3140    )?;
3141    writeln!(code, "    gate_up_buf: Buffer,")?;
3142    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
3143    writeln!(code, "    ffn_out_buf: Buffer,")?;
3144    writeln!(code, "    add_tmp_buf: Buffer,")?;
3145    writeln!(code, "    logits_buf: Buffer,")?;
3146    writeln!(code)?;
3147    writeln!(code, "    // ── Batched prefill working buffers ──")?;
3148    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3149    writeln!(code, "    batch_hidden_buf: Buffer,")?;
3150    writeln!(
3151        code,
3152        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3153    )?;
3154    writeln!(code, "    batch_residual_buf: Buffer,")?;
3155    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3156    writeln!(code, "    batch_qkv_buf: Buffer,")?;
3157    writeln!(
3158        code,
3159        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3160    )?;
3161    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
3162    writeln!(
3163        code,
3164        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3165    )?;
3166    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
3167    writeln!(
3168        code,
3169        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3170    )?;
3171    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
3172    writeln!(
3173        code,
3174        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3175    )?;
3176    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
3177    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3178    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
3179    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
3180    writeln!(code, "    batch_tokens_buf: Buffer,")?;
3181    writeln!(code, "    /// Positions buffer for batch RoPE")?;
3182    writeln!(code, "    batch_positions_buf: Buffer,")?;
3183    writeln!(code)?;
3184    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
3185    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
3186    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
3187    writeln!(code)?;
3188    writeln!(code, "    // ── Inference state ──")?;
3189    writeln!(code, "    pos: usize,")?;
3190    writeln!(code)?;
3191    writeln!(
3192        code,
3193        "    /// Previous command buffer for double-buffered prefill."
3194    )?;
3195    writeln!(
3196        code,
3197        "    /// While the GPU executes token N, the CPU can encode token N+1."
3198    )?;
3199    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
3200    writeln!(code, "}}")?;
3201    writeln!(code)?;
3202
3203    Ok(())
3204}
3205
3206fn emit_layer_buffers_struct(
3207    code: &mut String,
3208    config: &ModelConfig,
3209) -> Result<(), MetalCodegenError> {
3210    writeln!(
3211        code,
3212        "/// Per-layer weight buffers for attention and FFN projections."
3213    )?;
3214    writeln!(code, "struct LayerBuffers {{")?;
3215    writeln!(code, "    attn_norm: Buffer,")?;
3216    writeln!(
3217        code,
3218        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3219    )?;
3220    writeln!(code, "    qkv_weight: Buffer,")?;
3221    if config.qkv_bias {
3222        writeln!(
3223            code,
3224            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3225        )?;
3226        writeln!(code, "    qkv_bias: Buffer,")?;
3227    }
3228    writeln!(code, "    o_weight: Buffer,")?;
3229    writeln!(code, "    ffn_norm: Buffer,")?;
3230    writeln!(
3231        code,
3232        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3233    )?;
3234    writeln!(code, "    gate_up_weight: Buffer,")?;
3235    writeln!(code, "    down_weight: Buffer,")?;
3236    writeln!(code, "}}")?;
3237    writeln!(code)?;
3238
3239    Ok(())
3240}
3241
3242fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3243    let hidden = config.hidden_size;
3244    let intermediate = config.intermediate_size;
3245    let _num_layers = config.num_layers;
3246    let num_heads = config.num_attention_heads;
3247    let num_kv_heads = config.num_kv_heads;
3248    let head_dim = config.head_dim;
3249    let vocab = config.vocab_size;
3250    let effective_seq_len = config.max_seq_len.min(4096);
3251    let is_q8 = config.dtype == DType::Q8_0;
3252    let is_q4 = config.dtype == DType::Q4_0;
3253    let kv_dim = num_kv_heads * head_dim;
3254
3255    writeln!(code, "impl MetalModel {{")?;
3256
3257    // ── new() ──
3258    writeln!(
3259        code,
3260        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3261    )?;
3262    writeln!(code, "    ///")?;
3263    writeln!(
3264        code,
3265        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
3266    )?;
3267    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
3268    writeln!(
3269        code,
3270        "        let device = Device::system_default().expect(\"no Metal device found\");"
3271    )?;
3272    writeln!(code, "        let queue = device.new_command_queue();")?;
3273    writeln!(code)?;
3274
3275    // Compile shaders
3276    writeln!(
3277        code,
3278        "        // Compile Metal shaders from embedded source"
3279    )?;
3280    writeln!(
3281        code,
3282        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
3283    )?;
3284    writeln!(
3285        code,
3286        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3287    )?;
3288    writeln!(
3289        code,
3290        "            .expect(\"failed to compile Metal shaders\");"
3291    )?;
3292    writeln!(code)?;
3293
3294    // Create compute pipelines
3295    writeln!(code, "        // Create compute pipelines")?;
3296    for (var, fn_name) in [
3297        ("matmul_pipeline", "matmul_vec"),
3298        ("matmul_q8_pipeline", "matmul_vec_q8"),
3299        ("matmul_q4_pipeline", "matmul_vec_q4"),
3300        ("rms_norm_pipeline", "rms_norm"),
3301        ("rope_pipeline", "rope"),
3302        ("softmax_pipeline", "softmax"),
3303        ("silu_mul_pipeline", "silu_mul"),
3304        ("silu_mul_fused_pipeline", "silu_mul_fused"),
3305        ("add_pipeline", "elementwise_add"),
3306        ("attention_pipeline", "attention"),
3307        ("add_inplace_pipeline", "add_inplace"),
3308        ("copy_pipeline", "copy_buffer"),
3309        ("copy_offset_pipeline", "copy_offset"),
3310        ("matmul_batch_pipeline", "matmul_vec_batch"),
3311        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3312        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3313        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3314        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3315        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3316        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3317        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3318        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3319        ("rms_norm_batch_pipeline", "rms_norm_batch"),
3320        ("rope_batch_pipeline", "rope_batch"),
3321        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3322        ("add_inplace_batch_pipeline", "add_inplace_batch"),
3323        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3324        ("attention_batch_pipeline", "attention_batch"),
3325        ("attention_flash_batch_pipeline", "attention_flash_batch"),
3326        ("attention_mma_flash_batch_pipeline", "attention_mma_flash_batch"),
3327        ("copy_kv_batch_pipeline", "copy_kv_batch"),
3328        ("rope_qk_batch_pipeline", "rope_qk_batch"),
3329        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3330    ] {
3331        writeln!(
3332            code,
3333            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3334        )?;
3335    }
3336    if config.qkv_bias {
3337        writeln!(
3338            code,
3339            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3340        )?;
3341    }
3342    writeln!(code)?;
3343
3344    // Weight loading
3345    writeln!(
3346        code,
3347        "        // Load weights into Metal shared-memory buffers"
3348    )?;
3349    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
3350    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3351    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
3352    writeln!(code)?;
3353    writeln!(
3354        code,
3355        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
3356    )?;
3357    writeln!(code)?;
3358    writeln!(
3359        code,
3360        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3361    )?;
3362    writeln!(
3363        code,
3364        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3365    )?;
3366    writeln!(code, "            let byte_len = n * f32_size;")?;
3367    writeln!(code, "            let cur = cursor.get();")?;
3368    writeln!(
3369        code,
3370        "            let data = &weights[cur..cur + byte_len];"
3371    )?;
3372    writeln!(code, "            cursor.set(cur + byte_len);")?;
3373    writeln!(code, "            device.new_buffer_with_data(")?;
3374    writeln!(code, "                data.as_ptr() as *const _,")?;
3375    writeln!(code, "                byte_len as u64,")?;
3376    writeln!(
3377        code,
3378        "                MTLResourceOptions::StorageModeShared,"
3379    )?;
3380    writeln!(code, "            )")?;
3381    writeln!(code, "        }};")?;
3382    writeln!(code)?;
3383
3384    if is_q8 {
3385        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
3386        // We load them directly into Metal buffers without dequantizing,
3387        // and use the matmul_vec_q8 shader that operates on quantized data.
3388        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
3389        writeln!(
3390            code,
3391            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3392        )?;
3393        writeln!(
3394            code,
3395            "        // as raw bytes into a Metal buffer (no dequantization)."
3396        )?;
3397        writeln!(
3398            code,
3399            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3400        )?;
3401        writeln!(
3402            code,
3403            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3404        )?;
3405        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3406        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3407        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3408        writeln!(code, "            let cur = cursor.get();")?;
3409        writeln!(
3410            code,
3411            "            let data = &weights[cur..cur + total_raw];"
3412        )?;
3413        writeln!(code, "            cursor.set(cur + total_raw);")?;
3414        writeln!(code, "            device.new_buffer_with_data(")?;
3415        writeln!(code, "                data.as_ptr() as *const _,")?;
3416        writeln!(code, "                total_raw as u64,")?;
3417        writeln!(
3418            code,
3419            "                MTLResourceOptions::StorageModeShared,"
3420        )?;
3421        writeln!(code, "            )")?;
3422        writeln!(code, "        }};")?;
3423        writeln!(code)?;
3424        writeln!(
3425            code,
3426            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3427        )?;
3428        writeln!(
3429            code,
3430            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3431        )?;
3432        writeln!(
3433            code,
3434            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3435        )?;
3436        writeln!(
3437            code,
3438            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3439        )?;
3440        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3441        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3442        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3443        writeln!(code, "            let cur = cursor.get();")?;
3444        writeln!(
3445            code,
3446            "            let data = &weights[cur..cur + total_raw];"
3447        )?;
3448        writeln!(code, "            cursor.set(cur + total_raw);")?;
3449        writeln!(code, "            device.new_buffer_with_data(")?;
3450        writeln!(code, "                data.as_ptr() as *const _,")?;
3451        writeln!(code, "                total_raw as u64,")?;
3452        writeln!(
3453            code,
3454            "                MTLResourceOptions::StorageModeShared,"
3455        )?;
3456        writeln!(code, "            )")?;
3457        writeln!(code, "        }};")?;
3458        writeln!(code)?;
3459    }
3460
3461    if is_q4 {
3462        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3463        // We load them directly into Metal buffers without dequantizing,
3464        // and use the matmul_vec_q4 shader that operates on quantized data.
3465        // This quarters GPU memory usage vs f32 dequantization.
3466        writeln!(
3467            code,
3468            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3469        )?;
3470        writeln!(
3471            code,
3472            "        // as raw bytes into a Metal buffer (no dequantization)."
3473        )?;
3474        writeln!(
3475            code,
3476            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3477        )?;
3478        writeln!(
3479            code,
3480            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3481        )?;
3482        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3483        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3484        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3485        writeln!(code, "            let cur = cursor.get();")?;
3486        writeln!(
3487            code,
3488            "            let data = &weights[cur..cur + total_raw];"
3489        )?;
3490        writeln!(code, "            cursor.set(cur + total_raw);")?;
3491        writeln!(code, "            device.new_buffer_with_data(")?;
3492        writeln!(code, "                data.as_ptr() as *const _,")?;
3493        writeln!(code, "                total_raw as u64,")?;
3494        writeln!(
3495            code,
3496            "                MTLResourceOptions::StorageModeShared,"
3497        )?;
3498        writeln!(code, "            )")?;
3499        writeln!(code, "        }};")?;
3500        writeln!(code)?;
3501        writeln!(
3502            code,
3503            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3504        )?;
3505        writeln!(
3506            code,
3507            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3508        )?;
3509        writeln!(
3510            code,
3511            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3512        )?;
3513        writeln!(
3514            code,
3515            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3516        )?;
3517        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3518        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3519        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3520        writeln!(code, "            let cur = cursor.get();")?;
3521        writeln!(
3522            code,
3523            "            let data = &weights[cur..cur + total_raw];"
3524        )?;
3525        writeln!(code, "            cursor.set(cur + total_raw);")?;
3526        writeln!(code, "            device.new_buffer_with_data(")?;
3527        writeln!(code, "                data.as_ptr() as *const _,")?;
3528        writeln!(code, "                total_raw as u64,")?;
3529        writeln!(
3530            code,
3531            "                MTLResourceOptions::StorageModeShared,"
3532        )?;
3533        writeln!(code, "            )")?;
3534        writeln!(code, "        }};")?;
3535        writeln!(code)?;
3536    }
3537
3538    writeln!(
3539        code,
3540        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3541    )?;
3542    writeln!(code)?;
3543
3544    // Per-layer weights
3545    writeln!(
3546        code,
3547        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3548    )?;
3549    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3550
3551    // attn_norm is always f32
3552    writeln!(
3553        code,
3554        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3555    )?;
3556
3557    let qkv_rows = hidden + 2 * kv_dim;
3558    if is_q8 {
3559        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3560        writeln!(
3561            code,
3562            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3563        )?;
3564        if config.qkv_bias {
3565            writeln!(
3566                code,
3567                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3568            )?;
3569            writeln!(
3570                code,
3571                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3572            )?;
3573        }
3574        writeln!(
3575            code,
3576            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3577        )?;
3578    } else if is_q4 {
3579        writeln!(
3580            code,
3581            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3582        )?;
3583        if config.qkv_bias {
3584            writeln!(
3585                code,
3586                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3587            )?;
3588        }
3589        writeln!(
3590            code,
3591            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3592        )?;
3593    } else {
3594        writeln!(
3595            code,
3596            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3597        )?;
3598        if config.qkv_bias {
3599            writeln!(
3600                code,
3601                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3602            )?;
3603        }
3604        writeln!(
3605            code,
3606            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3607        )?;
3608    }
3609
3610    // ffn_norm is always f32
3611    writeln!(
3612        code,
3613        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3614    )?;
3615
3616    let gate_up_rows = 2 * intermediate;
3617    if is_q8 {
3618        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3619        writeln!(
3620            code,
3621            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3622        )?;
3623        writeln!(
3624            code,
3625            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3626        )?;
3627    } else if is_q4 {
3628        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3629        writeln!(
3630            code,
3631            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3632        )?;
3633        writeln!(
3634            code,
3635            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3636        )?;
3637    } else {
3638        // Fused gate+up weight: read both as a single contiguous f32 buffer
3639        writeln!(
3640            code,
3641            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3642        )?;
3643        writeln!(
3644            code,
3645            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3646        )?;
3647    }
3648
3649    writeln!(code, "            layers.push(LayerBuffers {{")?;
3650    writeln!(code, "                attn_norm,")?;
3651    writeln!(code, "                qkv_weight,")?;
3652    if config.qkv_bias {
3653        writeln!(code, "                qkv_bias,")?;
3654    }
3655    writeln!(code, "                o_weight,")?;
3656    writeln!(code, "                ffn_norm,")?;
3657    writeln!(code, "                gate_up_weight,")?;
3658    writeln!(code, "                down_weight,")?;
3659    writeln!(code, "            }});")?;
3660    writeln!(code, "        }}")?;
3661    writeln!(code)?;
3662
3663    // final_norm is always f32
3664    writeln!(
3665        code,
3666        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3667    )?;
3668    writeln!(code)?;
3669
3670    // lm_head
3671    if is_q8 {
3672        writeln!(
3673            code,
3674            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3675        )?;
3676    } else if is_q4 {
3677        writeln!(
3678            code,
3679            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3680        )?;
3681    } else {
3682        writeln!(
3683            code,
3684            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3685        )?;
3686    }
3687    writeln!(code)?;
3688
3689    // Working buffers
3690    let hidden_bytes = hidden * 4;
3691    let _kv_dim_bytes = kv_dim * 4;
3692    let intermediate_bytes = intermediate * 4;
3693    let vocab_bytes = vocab * 4;
3694    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3695
3696    writeln!(
3697        code,
3698        "        // Allocate working buffers (shared memory for zero-copy)"
3699    )?;
3700    writeln!(
3701        code,
3702        "        let opts = MTLResourceOptions::StorageModeShared;"
3703    )?;
3704    writeln!(
3705        code,
3706        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3707    )?;
3708    writeln!(
3709        code,
3710        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3711    )?;
3712    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3713    writeln!(
3714        code,
3715        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3716    )?;
3717    writeln!(
3718        code,
3719        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3720    )?;
3721    writeln!(
3722        code,
3723        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3724    )?;
3725    writeln!(
3726        code,
3727        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3728    )?;
3729    writeln!(
3730        code,
3731        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3732    )?;
3733    let gate_up_buf_bytes = 2 * intermediate * 4;
3734    writeln!(
3735        code,
3736        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3737    )?;
3738    writeln!(
3739        code,
3740        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3741    )?;
3742    writeln!(
3743        code,
3744        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3745    )?;
3746    writeln!(
3747        code,
3748        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3749    )?;
3750    writeln!(
3751        code,
3752        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3753    )?;
3754    writeln!(
3755        code,
3756        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3757    )?;
3758    writeln!(code)?;
3759
3760    // Batch prefill working buffers
3761    let batch_hidden_bytes = hidden * 4; // per-token
3762    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3763    let batch_gate_up_bytes = 2 * intermediate * 4;
3764    let batch_intermediate_bytes = intermediate * 4;
3765    writeln!(
3766        code,
3767        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3768    )?;
3769    writeln!(
3770        code,
3771        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3772    )?;
3773    writeln!(
3774        code,
3775        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3776    )?;
3777    writeln!(
3778        code,
3779        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3780    )?;
3781    writeln!(
3782        code,
3783        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3784    )?;
3785    writeln!(
3786        code,
3787        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3788    )?;
3789    writeln!(
3790        code,
3791        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3792    )?;
3793    writeln!(
3794        code,
3795        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3796    )?;
3797    writeln!(
3798        code,
3799        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3800    )?;
3801    writeln!(
3802        code,
3803        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3804    )?;
3805    writeln!(
3806        code,
3807        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3808    )?;
3809    writeln!(code)?;
3810
3811    // KV cache buffers
3812    writeln!(code, "        // KV cache buffers (per-layer)")?;
3813    writeln!(
3814        code,
3815        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3816    )?;
3817    writeln!(
3818        code,
3819        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3820    )?;
3821    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3822    writeln!(
3823        code,
3824        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3825    )?;
3826    writeln!(
3827        code,
3828        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3829    )?;
3830    writeln!(code, "        }}")?;
3831    writeln!(code)?;
3832
3833    writeln!(code, "        Self {{")?;
3834    writeln!(code, "            device,")?;
3835    writeln!(code, "            queue,")?;
3836    writeln!(code, "            matmul_pipeline,")?;
3837    writeln!(code, "            matmul_q8_pipeline,")?;
3838    writeln!(code, "            matmul_q4_pipeline,")?;
3839    writeln!(code, "            rms_norm_pipeline,")?;
3840    writeln!(code, "            rope_pipeline,")?;
3841    writeln!(code, "            softmax_pipeline,")?;
3842    writeln!(code, "            silu_mul_pipeline,")?;
3843    writeln!(code, "            silu_mul_fused_pipeline,")?;
3844    writeln!(code, "            add_pipeline,")?;
3845    writeln!(code, "            attention_pipeline,")?;
3846    writeln!(code, "            add_inplace_pipeline,")?;
3847    writeln!(code, "            copy_pipeline,")?;
3848    writeln!(code, "            copy_offset_pipeline,")?;
3849    writeln!(code, "            matmul_batch_pipeline,")?;
3850    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3851    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3852    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3853    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3854    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3855    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3856    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3857    if config.qkv_bias {
3858        writeln!(code, "            add_bias_batch_pipeline,")?;
3859    }
3860    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3861    writeln!(code, "            rms_norm_batch_pipeline,")?;
3862    writeln!(code, "            rope_batch_pipeline,")?;
3863    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3864    writeln!(code, "            add_inplace_batch_pipeline,")?;
3865    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3866    writeln!(code, "            attention_batch_pipeline,")?;
3867    writeln!(code, "            attention_flash_batch_pipeline,")?;
3868    writeln!(code, "            attention_mma_flash_batch_pipeline,")?;
3869    writeln!(code, "            copy_kv_batch_pipeline,")?;
3870    writeln!(code, "            rope_qk_batch_pipeline,")?;
3871    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3872    writeln!(code, "            embed_buf,")?;
3873    writeln!(code, "            layers,")?;
3874    writeln!(code, "            norm_buf,")?;
3875    writeln!(code, "            lm_head_buf,")?;
3876    writeln!(code, "            hidden_buf,")?;
3877    writeln!(code, "            residual_buf,")?;
3878    writeln!(code, "            normed_buf,")?;
3879    writeln!(code, "            qkv_buf,")?;
3880    writeln!(code, "            attn_out_buf,")?;
3881    writeln!(code, "            attn_proj_buf,")?;
3882    writeln!(code, "            gate_up_buf,")?;
3883    writeln!(code, "            ffn_hidden_buf,")?;
3884    writeln!(code, "            ffn_out_buf,")?;
3885    writeln!(code, "            add_tmp_buf,")?;
3886    writeln!(code, "            logits_buf,")?;
3887    writeln!(code, "            batch_hidden_buf,")?;
3888    writeln!(code, "            batch_residual_buf,")?;
3889    writeln!(code, "            batch_qkv_buf,")?;
3890    writeln!(code, "            batch_attn_out_buf,")?;
3891    writeln!(code, "            batch_attn_proj_buf,")?;
3892    writeln!(code, "            batch_gate_up_buf,")?;
3893    writeln!(code, "            batch_ffn_hidden_buf,")?;
3894    writeln!(code, "            batch_ffn_out_buf,")?;
3895    writeln!(code, "            batch_tokens_buf,")?;
3896    writeln!(code, "            batch_positions_buf,")?;
3897    writeln!(code, "            k_cache,")?;
3898    writeln!(code, "            v_cache,")?;
3899    writeln!(code, "            pos: 0,")?;
3900    writeln!(code, "            prev_cmd: None,")?;
3901    writeln!(code, "        }}")?;
3902    writeln!(code, "    }}")?;
3903    writeln!(code)?;
3904
3905    // ── forward() ──
3906    writeln!(
3907        code,
3908        "    /// Run the forward pass for a single token at the current position."
3909    )?;
3910    writeln!(code, "    ///")?;
3911    writeln!(
3912        code,
3913        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3914    )?;
3915    writeln!(code, "    ///")?;
3916    writeln!(
3917        code,
3918        "    /// All GPU operations are encoded into a single command buffer and"
3919    )?;
3920    writeln!(
3921        code,
3922        "    /// committed once at the end, avoiding per-operation synchronization."
3923    )?;
3924    writeln!(
3925        code,
3926        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3927    )?;
3928    writeln!(
3929        code,
3930        "        // Wait for any pending prefill command buffer"
3931    )?;
3932    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3933    writeln!(code, "            prev.wait_until_completed();")?;
3934    writeln!(code, "        }}")?;
3935    writeln!(code)?;
3936    writeln!(code, "        let pos = self.pos;")?;
3937    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3938    writeln!(code)?;
3939
3940    // Single compute encoder for the entire forward pass — no blit encoder
3941    // transitions. Copy operations use compute copy kernels instead of blits.
3942    let matmul_fn = if is_q8 {
3943        "dispatch_matmul_q8"
3944    } else if is_q4 {
3945        "dispatch_matmul_q4"
3946    } else {
3947        "dispatch_matmul"
3948    };
3949
3950    writeln!(
3951        code,
3952        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3953    )?;
3954    writeln!(code, "        {{")?;
3955    writeln!(
3956        code,
3957        "            let enc = cmd.new_compute_command_encoder();"
3958    )?;
3959    writeln!(code)?;
3960
3961    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3962    writeln!(
3963        code,
3964        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3965    )?;
3966    writeln!(
3967        code,
3968        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3969    )?;
3970    writeln!(
3971        code,
3972        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3973    )?;
3974    writeln!(
3975        code,
3976        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3977        hidden * 4,
3978    )?;
3979    writeln!(code, "            unsafe {{")?;
3980    writeln!(
3981        code,
3982        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
3983    )?;
3984    writeln!(
3985        code,
3986        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3987    )?;
3988    writeln!(
3989        code,
3990        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
3991    )?;
3992    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
3993    writeln!(
3994        code,
3995        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3996    )?;
3997    writeln!(code, "                    hidden_ptr,")?;
3998    writeln!(code, "                    HIDDEN_SIZE,")?;
3999    writeln!(code, "                );")?;
4000    writeln!(
4001        code,
4002        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4003    )?;
4004    writeln!(code, "            }}")?;
4005    writeln!(code)?;
4006
4007    // 2. Transformer layers
4008    writeln!(code, "            // 2. Transformer layers")?;
4009    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4010    writeln!(code)?;
4011    let q_byte_offset = 0usize;
4012    let k_byte_offset = hidden * 4;
4013    let v_byte_offset = (hidden + kv_dim) * 4;
4014
4015    writeln!(
4016        code,
4017        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
4018    )?;
4019    writeln!(
4020        code,
4021        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4022    )?;
4023    writeln!(
4024        code,
4025        "                // Fused Q+K+V matmul: single dispatch for all three projections"
4026    )?;
4027    writeln!(
4028        code,
4029        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4030    )?;
4031    if config.qkv_bias {
4032        writeln!(
4033            code,
4034            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4035        )?;
4036        writeln!(
4037            code,
4038            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4039        )?;
4040    }
4041    writeln!(
4042        code,
4043        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
4044    )?;
4045    writeln!(
4046        code,
4047        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4048    )?;
4049    writeln!(
4050        code,
4051        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4052    )?;
4053    writeln!(code)?;
4054    writeln!(
4055        code,
4056        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
4057    )?;
4058    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
4059    writeln!(
4060        code,
4061        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4062    )?;
4063    writeln!(
4064        code,
4065        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4066    )?;
4067    writeln!(code)?;
4068    writeln!(
4069        code,
4070        "                // Attention using Q from qkv_buf (offset 0)"
4071    )?;
4072    writeln!(
4073        code,
4074        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4075    )?;
4076    writeln!(
4077        code,
4078        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4079    )?;
4080    writeln!(
4081        code,
4082        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4083    )?;
4084    writeln!(
4085        code,
4086        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4087    )?;
4088    writeln!(
4089        code,
4090        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4091    )?;
4092    writeln!(
4093        code,
4094        "                // Fused gate+up matmul: single dispatch for both projections"
4095    )?;
4096    writeln!(
4097        code,
4098        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4099    )?;
4100    writeln!(
4101        code,
4102        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4103    )?;
4104    writeln!(
4105        code,
4106        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4107    )?;
4108    writeln!(
4109        code,
4110        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4111    )?;
4112    writeln!(code, "            }}")?;
4113    writeln!(code)?;
4114
4115    // 3. Final RMS norm + logits
4116    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
4117    writeln!(
4118        code,
4119        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4120    )?;
4121    writeln!(
4122        code,
4123        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4124    )?;
4125    writeln!(code)?;
4126    writeln!(code, "            enc.end_encoding();")?;
4127    writeln!(code, "        }}")?;
4128    writeln!(code)?;
4129
4130    // 5. Single commit + wait, then read back logits
4131    writeln!(
4132        code,
4133        "        // 5. Commit all GPU work and wait for completion"
4134    )?;
4135    writeln!(code, "        cmd.commit();")?;
4136    writeln!(code, "        cmd.wait_until_completed();")?;
4137    writeln!(code)?;
4138    writeln!(code, "        // 6. Read back logits from GPU")?;
4139    writeln!(code, "        let logits = unsafe {{")?;
4140    writeln!(
4141        code,
4142        "            let ptr = self.logits_buf.contents() as *const f32;"
4143    )?;
4144    writeln!(
4145        code,
4146        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4147    )?;
4148    writeln!(code, "        }};")?;
4149    writeln!(code)?;
4150    writeln!(code, "        self.pos += 1;")?;
4151    writeln!(code, "        logits")?;
4152    writeln!(code, "    }}")?;
4153    writeln!(code)?;
4154
4155    // ── forward_profile: instrumented forward with per-operation timing ──
4156    writeln!(
4157        code,
4158        "    /// Profiling forward pass that prints per-stage GPU timing."
4159    )?;
4160    writeln!(code, "    ///")?;
4161    writeln!(
4162        code,
4163        "    /// Each stage is committed and waited on separately so that GPU timestamps"
4164    )?;
4165    writeln!(
4166        code,
4167        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
4168    )?;
4169    writeln!(
4170        code,
4171        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
4172    )?;
4173    writeln!(
4174        code,
4175        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4176    )?;
4177    writeln!(code, "        use std::time::Instant;")?;
4178    writeln!(code)?;
4179    writeln!(
4180        code,
4181        "        // Wait for any pending prefill command buffer"
4182    )?;
4183    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4184    writeln!(code, "            prev.wait_until_completed();")?;
4185    writeln!(code, "        }}")?;
4186    writeln!(code)?;
4187    writeln!(code, "        let pos = self.pos;")?;
4188    writeln!(code)?;
4189
4190    // Stage: embedding (CPU, no GPU)
4191    writeln!(
4192        code,
4193        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
4194    )?;
4195    writeln!(code, "        let t_embed = Instant::now();")?;
4196    writeln!(code, "        unsafe {{")?;
4197    writeln!(
4198        code,
4199        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
4200    )?;
4201    writeln!(
4202        code,
4203        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4204    )?;
4205    writeln!(
4206        code,
4207        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
4208    )?;
4209    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
4210    writeln!(
4211        code,
4212        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4213    )?;
4214    writeln!(code, "                hidden_ptr,")?;
4215    writeln!(code, "                HIDDEN_SIZE,")?;
4216    writeln!(code, "            );")?;
4217    writeln!(
4218        code,
4219        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4220    )?;
4221    writeln!(code, "        }}")?;
4222    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
4223    writeln!(code)?;
4224
4225    // Stage: Transformer layers (all together on GPU)
4226    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
4227    writeln!(code, "        let t_layers = Instant::now();")?;
4228    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4229    writeln!(code, "        {{")?;
4230    writeln!(
4231        code,
4232        "            let enc = cmd.new_compute_command_encoder();"
4233    )?;
4234    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4235    writeln!(
4236        code,
4237        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4238    )?;
4239    writeln!(
4240        code,
4241        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4242    )?;
4243    if config.qkv_bias {
4244        writeln!(
4245            code,
4246            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4247        )?;
4248    }
4249    writeln!(
4250        code,
4251        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4252    )?;
4253    writeln!(
4254        code,
4255        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4256    )?;
4257    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
4258    writeln!(
4259        code,
4260        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4261    )?;
4262    writeln!(
4263        code,
4264        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4265    )?;
4266    writeln!(
4267        code,
4268        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4269    )?;
4270    writeln!(
4271        code,
4272        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4273    )?;
4274    writeln!(
4275        code,
4276        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4277    )?;
4278    writeln!(
4279        code,
4280        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4281    )?;
4282    writeln!(
4283        code,
4284        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4285    )?;
4286    writeln!(
4287        code,
4288        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4289    )?;
4290    writeln!(
4291        code,
4292        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4293    )?;
4294    writeln!(
4295        code,
4296        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4297    )?;
4298    writeln!(code, "            }}")?;
4299    writeln!(code, "            enc.end_encoding();")?;
4300    writeln!(code, "        }}")?;
4301    writeln!(code, "        cmd.commit();")?;
4302    writeln!(code, "        cmd.wait_until_completed();")?;
4303    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
4304    writeln!(code)?;
4305
4306    // Stage: Final norm + logits
4307    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
4308    writeln!(code, "        let t_logits = Instant::now();")?;
4309    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
4310    writeln!(code, "        {{")?;
4311    writeln!(
4312        code,
4313        "            let enc = cmd2.new_compute_command_encoder();"
4314    )?;
4315    writeln!(
4316        code,
4317        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
4318    )?;
4319    writeln!(
4320        code,
4321        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4322    )?;
4323    writeln!(code, "            enc.end_encoding();")?;
4324    writeln!(code, "        }}")?;
4325    writeln!(code, "        cmd2.commit();")?;
4326    writeln!(code, "        cmd2.wait_until_completed();")?;
4327    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
4328    writeln!(code)?;
4329
4330    // Print profile results
4331    writeln!(
4332        code,
4333        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4334    )?;
4335    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
4336    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
4337    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
4338    writeln!(
4339        code,
4340        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4341    )?;
4342    writeln!(code)?;
4343
4344    // Read back logits
4345    writeln!(code, "        let logits = unsafe {{")?;
4346    writeln!(
4347        code,
4348        "            let ptr = self.logits_buf.contents() as *const f32;"
4349    )?;
4350    writeln!(
4351        code,
4352        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4353    )?;
4354    writeln!(code, "        }};")?;
4355    writeln!(code)?;
4356    writeln!(code, "        self.pos += 1;")?;
4357    writeln!(code, "        logits")?;
4358    writeln!(code, "    }}")?;
4359    writeln!(code)?;
4360
4361    // ── forward_prefill: single-token async forward (backward compat) ──
4362    writeln!(
4363        code,
4364        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
4365    )?;
4366    writeln!(code, "    ///")?;
4367    writeln!(
4368        code,
4369        "    /// Commits the command buffer without waiting, enabling double-buffered"
4370    )?;
4371    writeln!(
4372        code,
4373        "    /// execution: GPU processes token N while CPU encodes token N+1."
4374    )?;
4375    writeln!(
4376        code,
4377        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
4378    )?;
4379    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
4380    writeln!(code, "    }}")?;
4381    writeln!(code)?;
4382
4383    // ── forward_prefill_batch: batched prefill for multiple tokens ──
4384    // Batched matmuls for QKV/O/FFN projections, sequential attention (causal dependency).
4385    let batch_matmul_fn = if is_q8 {
4386        "dispatch_matmul_q8_batch"
4387    } else if is_q4 {
4388        "dispatch_matmul_q4_batch"
4389    } else {
4390        "dispatch_matmul_batch"
4391    };
4392
4393    writeln!(
4394        code,
4395        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4396    )?;
4397    writeln!(code, "    ///")?;
4398    writeln!(
4399        code,
4400        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4401    )?;
4402    writeln!(
4403        code,
4404        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
4405    )?;
4406    writeln!(
4407        code,
4408        "    /// This provides significant speedup during prompt prefill."
4409    )?;
4410    writeln!(
4411        code,
4412        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4413    )?;
4414    writeln!(code, "        if tokens.is_empty() {{ return; }}")?;
4415    writeln!(
4416        code,
4417        "        // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4418    )?;
4419    writeln!(
4420        code,
4421        "        // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4422    )?;
4423    writeln!(
4424        code,
4425        "        // longer than that must be processed iteratively.  The KV cache"
4426    )?;
4427    writeln!(code, "        // carries state across chunks via self.pos.")?;
4428    writeln!(
4429        code,
4430        "        for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4431    )?;
4432    writeln!(code, "        let m = chunk.len();")?;
4433    writeln!(code, "        if m == 0 {{ continue; }}")?;
4434    writeln!(code, "        let start_pos = self.pos;")?;
4435    writeln!(code)?;
4436    writeln!(code, "        // Wait for any pending command buffer")?;
4437    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
4438    writeln!(code, "            prev.wait_until_completed();")?;
4439    writeln!(code, "        }}")?;
4440    writeln!(code)?;
4441
4442    // Upload token IDs and positions to GPU
4443    writeln!(
4444        code,
4445        "        // Upload token IDs and positions to GPU buffers"
4446    )?;
4447    writeln!(code, "        unsafe {{")?;
4448    writeln!(
4449        code,
4450        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4451    )?;
4452    writeln!(
4453        code,
4454        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4455    )?;
4456    writeln!(code, "            for i in 0..m {{")?;
4457    writeln!(code, "                *tok_ptr.add(i) = chunk[i];")?;
4458    writeln!(
4459        code,
4460        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4461    )?;
4462    writeln!(code, "            }}")?;
4463    writeln!(code, "        }}")?;
4464    writeln!(code)?;
4465
4466    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4467    writeln!(code, "        {{")?;
4468    writeln!(
4469        code,
4470        "            let enc = cmd.new_compute_command_encoder();"
4471    )?;
4472    writeln!(code)?;
4473
4474    // 1. Batch embedding lookup
4475    writeln!(
4476        code,
4477        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4478    )?;
4479    writeln!(
4480        code,
4481        "            self.dispatch_copy_embedding_batch(&enc, m);"
4482    )?;
4483    // Copy batch_hidden -> batch_residual
4484    writeln!(
4485        code,
4486        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4487    )?;
4488    writeln!(code)?;
4489
4490    // 2. Transformer layers
4491    writeln!(code, "            // 2. Transformer layers")?;
4492    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4493    writeln!(code)?;
4494
4495    // Batch RMS norm: residual -> hidden (batched)
4496    writeln!(
4497        code,
4498        "                // Batch RMS norm: batch_residual -> batch_hidden"
4499    )?;
4500    writeln!(
4501        code,
4502        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4503    )?;
4504
4505    // Batch QKV matmul
4506    writeln!(
4507        code,
4508        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4509    )?;
4510    writeln!(
4511        code,
4512        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4513    )?;
4514    if config.qkv_bias {
4515        writeln!(
4516            code,
4517            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4518        )?;
4519        writeln!(
4520            code,
4521            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4522        )?;
4523    }
4524    writeln!(code)?;
4525
4526    // Fused RoPE on Q+K portions in a single dispatch
4527    let k_float_offset = hidden;
4528    writeln!(
4529        code,
4530        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4531    )?;
4532    writeln!(
4533        code,
4534        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4535    )?;
4536    writeln!(code)?;
4537
4538    // Fused KV cache update: copy both K and V in a single dispatch
4539    let v_float_offset = hidden + kv_dim;
4540    writeln!(
4541        code,
4542        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4543    )?;
4544    writeln!(
4545        code,
4546        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4547    )?;
4548    writeln!(code)?;
4549
4550    // Batched causal attention: ONE dispatch for all M tokens
4551    writeln!(
4552        code,
4553        "                // Batched causal attention: one dispatch for all M tokens"
4554    )?;
4555    writeln!(
4556        code,
4557        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4558    )?;
4559    writeln!(code)?;
4560
4561    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4562    writeln!(code, "                // Batched O projection")?;
4563    writeln!(
4564        code,
4565        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4566    )?;
4567    writeln!(code)?;
4568
4569    // Batch add: residual += attn_proj for all tokens
4570    writeln!(
4571        code,
4572        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4573    )?;
4574    writeln!(code)?;
4575
4576    // Batch FFN
4577    writeln!(
4578        code,
4579        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4580    )?;
4581    writeln!(
4582        code,
4583        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4584    )?;
4585    writeln!(
4586        code,
4587        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4588    )?;
4589    writeln!(
4590        code,
4591        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4592    )?;
4593    writeln!(
4594        code,
4595        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4596    )?;
4597    writeln!(
4598        code,
4599        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4600    )?;
4601    writeln!(code, "            }}")?;
4602    writeln!(code)?;
4603
4604    // Copy last token's residual to single-token residual_buf for next forward() call
4605    writeln!(
4606        code,
4607        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4608    )?;
4609    writeln!(
4610        code,
4611        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4612    )?;
4613    writeln!(code)?;
4614    writeln!(code, "            enc.end_encoding();")?;
4615    writeln!(code, "        }}")?;
4616    writeln!(code)?;
4617
4618    writeln!(code, "        cmd.commit();")?;
4619    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4620    writeln!(code, "        self.pos += m;")?;
4621    writeln!(code, "        }}  // end for chunk")?;
4622    writeln!(code, "    }}")?;
4623    writeln!(code)?;
4624
4625    // ── reset() — rewind KV cache position for new inference requests ──
4626    writeln!(
4627        code,
4628        "    /// Reset the model state for a new inference request."
4629    )?;
4630    writeln!(code, "    pub fn reset(&mut self) {{")?;
4631    writeln!(code, "        self.pos = 0;")?;
4632    writeln!(code, "        self.prev_cmd = None;")?;
4633    writeln!(code, "    }}")?;
4634    writeln!(code)?;
4635
4636    // ── Private dispatch helpers (all take a shared compute encoder) ──
4637    writeln!(
4638        code,
4639        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4640    )?;
4641    writeln!(
4642        code,
4643        "    // These methods set pipeline state + buffers + dispatch on an existing"
4644    )?;
4645    writeln!(
4646        code,
4647        "    // encoder, avoiding per-operation encoder creation overhead."
4648    )?;
4649    writeln!(code)?;
4650
4651    // dispatch_rms_norm
4652    writeln!(
4653        code,
4654        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4655    )?;
4656    writeln!(
4657        code,
4658        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4659    )?;
4660    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4661    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4662    writeln!(
4663        code,
4664        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4665    )?;
4666    writeln!(
4667        code,
4668        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4669    )?;
4670    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4671    writeln!(
4672        code,
4673        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4674    )?;
4675    writeln!(
4676        code,
4677        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4678    )?;
4679    writeln!(
4680        code,
4681        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4682    )?;
4683    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4684    writeln!(
4685        code,
4686        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4687    )?;
4688    writeln!(
4689        code,
4690        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4691    )?;
4692    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4693    writeln!(code, "    }}")?;
4694    writeln!(code)?;
4695
4696    // dispatch_matmul
4697    writeln!(
4698        code,
4699        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4700    )?;
4701    writeln!(
4702        code,
4703        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4704    )?;
4705    writeln!(code, "        let r: u32 = rows as u32;")?;
4706    writeln!(code, "        let c: u32 = cols as u32;")?;
4707    writeln!(
4708        code,
4709        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4710    )?;
4711    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4712    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4713    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4714    writeln!(
4715        code,
4716        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4717    )?;
4718    writeln!(
4719        code,
4720        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4721    )?;
4722    writeln!(
4723        code,
4724        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4725    )?;
4726    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4727    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4728    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4729    writeln!(
4730        code,
4731        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4732    )?;
4733    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4734    writeln!(code, "    }}")?;
4735    writeln!(code)?;
4736
4737    // dispatch_matmul_q8
4738    writeln!(
4739        code,
4740        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4741    )?;
4742    writeln!(
4743        code,
4744        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4745    )?;
4746    writeln!(
4747        code,
4748        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4749    )?;
4750    writeln!(code, "        let r: u32 = rows as u32;")?;
4751    writeln!(code, "        let c: u32 = cols as u32;")?;
4752    writeln!(
4753        code,
4754        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4755    )?;
4756    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4757    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4758    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4759    writeln!(
4760        code,
4761        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4762    )?;
4763    writeln!(
4764        code,
4765        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4766    )?;
4767    writeln!(
4768        code,
4769        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4770    )?;
4771    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4772    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4773    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4774    writeln!(
4775        code,
4776        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4777    )?;
4778    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4779    writeln!(code, "    }}")?;
4780    writeln!(code)?;
4781
4782    // dispatch_matmul_q4
4783    writeln!(
4784        code,
4785        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4786    )?;
4787    writeln!(
4788        code,
4789        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4790    )?;
4791    writeln!(
4792        code,
4793        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4794    )?;
4795    writeln!(code, "        let r: u32 = rows as u32;")?;
4796    writeln!(code, "        let c: u32 = cols as u32;")?;
4797    writeln!(
4798        code,
4799        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4800    )?;
4801    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4802    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4803    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4804    writeln!(
4805        code,
4806        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4807    )?;
4808    writeln!(
4809        code,
4810        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4811    )?;
4812    writeln!(
4813        code,
4814        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4815    )?;
4816    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4817    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4818    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4819    writeln!(
4820        code,
4821        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4822    )?;
4823    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4824    writeln!(code, "    }}")?;
4825    writeln!(code)?;
4826
4827    // dispatch_rope
4828    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4829    writeln!(
4830        code,
4831        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4832    )?;
4833    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4834    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4835    writeln!(code, "        let p: u32 = pos as u32;")?;
4836    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4837    writeln!(
4838        code,
4839        "        let total_pairs = num_heads * (head_dim / 2);"
4840    )?;
4841    writeln!(
4842        code,
4843        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4844    )?;
4845    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4846    writeln!(
4847        code,
4848        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4849    )?;
4850    writeln!(
4851        code,
4852        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4853    )?;
4854    writeln!(
4855        code,
4856        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4857    )?;
4858    writeln!(
4859        code,
4860        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4861    )?;
4862    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4863    writeln!(
4864        code,
4865        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4866    )?;
4867    writeln!(
4868        code,
4869        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4870    )?;
4871    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4872    writeln!(code, "    }}")?;
4873    writeln!(code)?;
4874
4875    // dispatch_rope_offset
4876    writeln!(
4877        code,
4878        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4879    )?;
4880    writeln!(
4881        code,
4882        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4883    )?;
4884    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4885    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4886    writeln!(code, "        let p: u32 = pos as u32;")?;
4887    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4888    writeln!(
4889        code,
4890        "        let total_pairs = num_heads * (head_dim / 2);"
4891    )?;
4892    writeln!(
4893        code,
4894        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4895    )?;
4896    writeln!(
4897        code,
4898        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4899    )?;
4900    writeln!(
4901        code,
4902        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4903    )?;
4904    writeln!(
4905        code,
4906        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4907    )?;
4908    writeln!(
4909        code,
4910        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4911    )?;
4912    writeln!(
4913        code,
4914        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4915    )?;
4916    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4917    writeln!(
4918        code,
4919        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4920    )?;
4921    writeln!(
4922        code,
4923        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4924    )?;
4925    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4926    writeln!(code, "    }}")?;
4927    writeln!(code)?;
4928
4929    // dispatch_attention
4930    writeln!(
4931        code,
4932        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4933    )?;
4934    writeln!(
4935        code,
4936        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4937    )?;
4938    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4939    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4940    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4941    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4942    writeln!(
4943        code,
4944        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4945    )?;
4946    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4947    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4948    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4949    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4950    writeln!(
4951        code,
4952        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4953    )?;
4954    writeln!(
4955        code,
4956        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4957    )?;
4958    writeln!(
4959        code,
4960        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4961    )?;
4962    writeln!(
4963        code,
4964        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4965    )?;
4966    writeln!(
4967        code,
4968        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4969    )?;
4970    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4971    writeln!(
4972        code,
4973        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4974    )?;
4975    writeln!(
4976        code,
4977        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4978    )?;
4979    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4980    writeln!(code, "    }}")?;
4981    writeln!(code)?;
4982
4983    // dispatch_attention_offset
4984    writeln!(
4985        code,
4986        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4987    )?;
4988    writeln!(
4989        code,
4990        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4991    )?;
4992    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4993    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4994    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4995    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4996    writeln!(
4997        code,
4998        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4999    )?;
5000    writeln!(
5001        code,
5002        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5003    )?;
5004    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5005    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5006    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5007    writeln!(
5008        code,
5009        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5010    )?;
5011    writeln!(
5012        code,
5013        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5014    )?;
5015    writeln!(
5016        code,
5017        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5018    )?;
5019    writeln!(
5020        code,
5021        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5022    )?;
5023    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5024    writeln!(
5025        code,
5026        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5027    )?;
5028    writeln!(
5029        code,
5030        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5031    )?;
5032    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5033    writeln!(code, "    }}")?;
5034    writeln!(code)?;
5035
5036    // dispatch_silu_mul
5037    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
5038    writeln!(
5039        code,
5040        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5041    )?;
5042    writeln!(code, "        let count: u32 = n as u32;")?;
5043    writeln!(
5044        code,
5045        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5046    )?;
5047    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
5048    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
5049    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5050    writeln!(
5051        code,
5052        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5053    )?;
5054    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5055    writeln!(
5056        code,
5057        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5058    )?;
5059    writeln!(
5060        code,
5061        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5062    )?;
5063    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5064    writeln!(code, "    }}")?;
5065    writeln!(code)?;
5066
5067    // dispatch_silu_mul_fused
5068    writeln!(
5069        code,
5070        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5071    )?;
5072    writeln!(
5073        code,
5074        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
5075    )?;
5076    writeln!(
5077        code,
5078        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5079    )?;
5080    writeln!(code, "        let count: u32 = n as u32;")?;
5081    writeln!(
5082        code,
5083        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5084    )?;
5085    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5086    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5087    writeln!(
5088        code,
5089        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5090    )?;
5091    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5092    writeln!(
5093        code,
5094        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5095    )?;
5096    writeln!(
5097        code,
5098        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5099    )?;
5100    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5101    writeln!(code, "    }}")?;
5102    writeln!(code)?;
5103
5104    // dispatch_copy (simple src -> dst copy via compute kernel)
5105    writeln!(
5106        code,
5107        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5108    )?;
5109    writeln!(
5110        code,
5111        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5112    )?;
5113    writeln!(code, "        let n: u32 = count as u32;")?;
5114    writeln!(
5115        code,
5116        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5117    )?;
5118    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5119    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5120    writeln!(
5121        code,
5122        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5123    )?;
5124    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5125    writeln!(
5126        code,
5127        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5128    )?;
5129    writeln!(
5130        code,
5131        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5132    )?;
5133    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5134    writeln!(code, "    }}")?;
5135    writeln!(code)?;
5136
5137    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
5138    writeln!(
5139        code,
5140        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5141    )?;
5142    writeln!(
5143        code,
5144        "    /// Used for embedding table lookup (copy a specific row)."
5145    )?;
5146    writeln!(
5147        code,
5148        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5149    )?;
5150    writeln!(code, "        let off: u32 = src_offset as u32;")?;
5151    writeln!(code, "        let n: u32 = count as u32;")?;
5152    writeln!(
5153        code,
5154        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5155    )?;
5156    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5157    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5158    writeln!(
5159        code,
5160        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5161    )?;
5162    writeln!(
5163        code,
5164        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5165    )?;
5166    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5167    writeln!(
5168        code,
5169        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5170    )?;
5171    writeln!(
5172        code,
5173        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5174    )?;
5175    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5176    writeln!(code, "    }}")?;
5177    writeln!(code)?;
5178
5179    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
5180    writeln!(
5181        code,
5182        "    /// Dispatch copy from source at byte offset to destination at float offset."
5183    )?;
5184    writeln!(
5185        code,
5186        "    /// Used for KV cache updates from fused QKV buffer."
5187    )?;
5188    writeln!(
5189        code,
5190        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5191    )?;
5192    writeln!(code, "        let n: u32 = count as u32;")?;
5193    writeln!(
5194        code,
5195        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5196    )?;
5197    writeln!(
5198        code,
5199        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5200    )?;
5201    writeln!(
5202        code,
5203        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5204    )?;
5205    writeln!(
5206        code,
5207        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5208    )?;
5209    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5210    writeln!(
5211        code,
5212        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5213    )?;
5214    writeln!(
5215        code,
5216        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5217    )?;
5218    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5219    writeln!(code, "    }}")?;
5220    writeln!(code)?;
5221
5222    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
5223    writeln!(
5224        code,
5225        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5226    )?;
5227    writeln!(
5228        code,
5229        "    /// Used for KV cache updates (write to a specific position in the cache)."
5230    )?;
5231    writeln!(
5232        code,
5233        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5234    )?;
5235    writeln!(code, "        let n: u32 = count as u32;")?;
5236    writeln!(
5237        code,
5238        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5239    )?;
5240    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5241    writeln!(
5242        code,
5243        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5244    )?;
5245    writeln!(
5246        code,
5247        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5248    )?;
5249    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5250    writeln!(
5251        code,
5252        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5253    )?;
5254    writeln!(
5255        code,
5256        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5257    )?;
5258    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5259    writeln!(code, "    }}")?;
5260    writeln!(code)?;
5261
5262    // dispatch_add_inplace (residual connection, no blit needed)
5263    writeln!(
5264        code,
5265        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5266    )?;
5267    writeln!(
5268        code,
5269        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5270    )?;
5271    writeln!(code, "        let count: u32 = n as u32;")?;
5272    writeln!(
5273        code,
5274        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5275    )?;
5276    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5277    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5278    writeln!(
5279        code,
5280        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5281    )?;
5282    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5283    writeln!(
5284        code,
5285        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5286    )?;
5287    writeln!(
5288        code,
5289        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5290    )?;
5291    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5292    writeln!(code, "    }}")?;
5293    writeln!(code)?;
5294
5295    // ── Batched prefill dispatch helpers ──
5296    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
5297    writeln!(code)?;
5298
5299    // dispatch_copy_embedding_batch
5300    writeln!(
5301        code,
5302        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
5303    )?;
5304    writeln!(
5305        code,
5306        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5307    )?;
5308    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
5309    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5310    writeln!(
5311        code,
5312        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5313    )?;
5314    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5315    writeln!(
5316        code,
5317        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5318    )?;
5319    writeln!(
5320        code,
5321        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5322    )?;
5323    writeln!(
5324        code,
5325        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5326    )?;
5327    writeln!(
5328        code,
5329        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5330    )?;
5331    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
5332    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5333    writeln!(
5334        code,
5335        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5336    )?;
5337    writeln!(
5338        code,
5339        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5340    )?;
5341    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5342    writeln!(code, "    }}")?;
5343    writeln!(code)?;
5344
5345    // dispatch_rms_norm_batch
5346    writeln!(
5347        code,
5348        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
5349    )?;
5350    writeln!(
5351        code,
5352        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5353    )?;
5354    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
5355    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
5356    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5357    writeln!(
5358        code,
5359        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5360    )?;
5361    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
5362    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
5363    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5364    writeln!(
5365        code,
5366        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5367    )?;
5368    writeln!(
5369        code,
5370        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5371    )?;
5372    writeln!(
5373        code,
5374        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5375    )?;
5376    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5377    writeln!(
5378        code,
5379        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5380    )?;
5381    writeln!(
5382        code,
5383        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5384    )?;
5385    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5386    writeln!(code, "    }}")?;
5387    writeln!(code)?;
5388
5389    // dispatch_matmul_batch (f32)
5390    writeln!(
5391        code,
5392        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5393    )?;
5394    writeln!(
5395        code,
5396        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5397    )?;
5398    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5399    writeln!(code, "        let r: u32 = rows as u32;")?;
5400    writeln!(code, "        let c: u32 = cols as u32;")?;
5401    writeln!(
5402        code,
5403        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5404    )?;
5405    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5406    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5407    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5408    writeln!(
5409        code,
5410        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5411    )?;
5412    writeln!(
5413        code,
5414        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5415    )?;
5416    writeln!(
5417        code,
5418        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5419    )?;
5420    writeln!(
5421        code,
5422        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
5423    )?;
5424    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5425    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5426    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5427    writeln!(
5428        code,
5429        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5430    )?;
5431    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5432    writeln!(code, "    }}")?;
5433    writeln!(code)?;
5434
5435    // dispatch_matmul_q8_batch
5436    writeln!(
5437        code,
5438        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5439    )?;
5440    writeln!(code, "    ///")?;
5441    writeln!(
5442        code,
5443        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5444    )?;
5445    writeln!(
5446        code,
5447        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5448    )?;
5449    writeln!(
5450        code,
5451        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5452    )?;
5453    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5454    writeln!(code, "        let r: u32 = rows as u32;")?;
5455    writeln!(code, "        let c: u32 = cols as u32;")?;
5456    writeln!(
5457        code,
5458        "        // Tile sizes must match the Metal shader constants."
5459    )?;
5460    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5461    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5462    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5463    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5464    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5465    writeln!(
5466        code,
5467        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5468    )?;
5469    writeln!(
5470        code,
5471        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5472    )?;
5473    writeln!(
5474        code,
5475        "        // dispatch count and reuses each weight load across 32 tokens."
5476    )?;
5477    writeln!(
5478        code,
5479        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5480    )?;
5481    writeln!(
5482        code,
5483        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5484    )?;
5485    writeln!(
5486        code,
5487        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5488    )?;
5489    writeln!(
5490        code,
5491        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5492    )?;
5493    writeln!(
5494        code,
5495        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5496    )?;
5497    writeln!(
5498        code,
5499        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5500    )?;
5501    writeln!(
5502        code,
5503        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5504    )?;
5505    writeln!(
5506        code,
5507        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5508    )?;
5509    writeln!(
5510        code,
5511        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5512    )?;
5513    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5514    writeln!(code, "            let pipe = if use_h4 {{")?;
5515    writeln!(code, "                if num_tokens >= 256 {{")?;
5516    writeln!(
5517        code,
5518        "                    &self.matmul_q8_mma32_hh4_pipeline"
5519    )?;
5520    writeln!(code, "                }} else {{")?;
5521    writeln!(
5522        code,
5523        "                    &self.matmul_q8_mma32_h4_pipeline"
5524    )?;
5525    writeln!(code, "                }}")?;
5526    writeln!(code, "            }} else {{")?;
5527    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5528    writeln!(code, "            }};")?;
5529    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
5530    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5531    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5532    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5533    writeln!(
5534        code,
5535        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5536    )?;
5537    writeln!(
5538        code,
5539        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5540    )?;
5541    writeln!(
5542        code,
5543        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5544    )?;
5545    writeln!(code, "            let row_tgs = rows / MMA32_ROW_TILE;")?;
5546    writeln!(
5547        code,
5548        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5549    )?;
5550    writeln!(
5551        code,
5552        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5553    )?;
5554    writeln!(
5555        code,
5556        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5557    )?;
5558    writeln!(
5559        code,
5560        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5561    )?;
5562    writeln!(
5563        code,
5564        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5565    )?;
5566    writeln!(
5567        code,
5568        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5569    )?;
5570    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5571    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5572    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5573    writeln!(
5574        code,
5575        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5576    )?;
5577    writeln!(
5578        code,
5579        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5580    )?;
5581    writeln!(
5582        code,
5583        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5584    )?;
5585    writeln!(code, "            let row_tgs = rows / MMA_ROW_TILE;")?;
5586    writeln!(
5587        code,
5588        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5589    )?;
5590    writeln!(
5591        code,
5592        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5593    )?;
5594    writeln!(
5595        code,
5596        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5597    )?;
5598    writeln!(
5599        code,
5600        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5601    )?;
5602    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5603    writeln!(
5604        code,
5605        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5606    )?;
5607    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5608    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5609    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5610    writeln!(
5611        code,
5612        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5613    )?;
5614    writeln!(
5615        code,
5616        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5617    )?;
5618    writeln!(
5619        code,
5620        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5621    )?;
5622    writeln!(
5623        code,
5624        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5625    )?;
5626    writeln!(
5627        code,
5628        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5629    )?;
5630    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5631    writeln!(
5632        code,
5633        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5634    )?;
5635    writeln!(
5636        code,
5637        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5638    )?;
5639    writeln!(code, "        }} else {{")?;
5640    writeln!(
5641        code,
5642        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5643    )?;
5644    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5645    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5646    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5647    writeln!(
5648        code,
5649        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5650    )?;
5651    writeln!(
5652        code,
5653        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5654    )?;
5655    writeln!(
5656        code,
5657        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5658    )?;
5659    writeln!(
5660        code,
5661        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5662    )?;
5663    writeln!(
5664        code,
5665        "            let num_tg = (row_tgs * num_tokens) as u64;"
5666    )?;
5667    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5668    writeln!(
5669        code,
5670        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5671    )?;
5672    writeln!(
5673        code,
5674        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5675    )?;
5676    writeln!(code, "        }}")?;
5677    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5678    writeln!(code, "    }}")?;
5679    writeln!(code)?;
5680
5681    // dispatch_matmul_q4_batch
5682    writeln!(
5683        code,
5684        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5685    )?;
5686    writeln!(
5687        code,
5688        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5689    )?;
5690    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5691    writeln!(code, "        let r: u32 = rows as u32;")?;
5692    writeln!(code, "        let c: u32 = cols as u32;")?;
5693    writeln!(
5694        code,
5695        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5696    )?;
5697    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5698    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5699    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5700    writeln!(
5701        code,
5702        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5703    )?;
5704    writeln!(
5705        code,
5706        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5707    )?;
5708    writeln!(
5709        code,
5710        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5711    )?;
5712    writeln!(
5713        code,
5714        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5715    )?;
5716    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5717    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5718    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5719    writeln!(
5720        code,
5721        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5722    )?;
5723    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5724    writeln!(code, "    }}")?;
5725    writeln!(code)?;
5726
5727    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5728    if config.qkv_bias {
5729        writeln!(
5730            code,
5731            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5732        )?;
5733        writeln!(
5734            code,
5735            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5736        )?;
5737        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5738        writeln!(code, "        let r: u32 = rows as u32;")?;
5739        writeln!(
5740            code,
5741            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5742        )?;
5743        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5744        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5745        writeln!(
5746            code,
5747            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5748        )?;
5749        writeln!(
5750            code,
5751            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5752        )?;
5753        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5754        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5755        writeln!(
5756            code,
5757            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5758        )?;
5759        writeln!(
5760            code,
5761            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5762        )?;
5763        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5764        writeln!(code, "    }}")?;
5765        writeln!(code)?;
5766    }
5767
5768    // dispatch_rope_batch
5769    writeln!(
5770        code,
5771        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5772    )?;
5773    writeln!(
5774        code,
5775        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5776    )?;
5777    writeln!(
5778        code,
5779        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5780    )?;
5781    writeln!(
5782        code,
5783        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5784    )?;
5785    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5786    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5787    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5788    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5789    writeln!(
5790        code,
5791        "        let pairs_per_token = num_heads * (head_dim / 2);"
5792    )?;
5793    writeln!(
5794        code,
5795        "        let total_pairs = num_tokens * pairs_per_token;"
5796    )?;
5797    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data.
5798    // Since our batch_qkv_buf is [M, qkv_rows] and Q/K are at offsets within each row,
5799    // we need to pass the buffer at the right byte offset for each token's data.
5800    // Actually, the rope_batch kernel accesses data[token * (num_heads * head_dim) + ...],
5801    // but our layout is data[token * row_stride + data_float_offset + ...].
5802    // We need the kernel to know the row_stride. Let me adjust the kernel approach:
5803    // Since Q and K are contiguous within each token's qkv_rows, and the batch buffer
5804    // is [M, qkv_rows], we can pass the buffer at offset (data_float_offset * 4) and
5805    // use a stride parameter. But the rope_batch kernel as written expects [M, num_heads*head_dim].
5806    //
5807    // Simplest approach: use the single-token rope kernel for each token in a loop.
5808    // This is still efficient because we're dispatching all within the same command encoder.
5809    writeln!(
5810        code,
5811        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5812    )?;
5813    writeln!(code, "        for t in 0..num_tokens {{")?;
5814    writeln!(
5815        code,
5816        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5817    )?;
5818    writeln!(
5819        code,
5820        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5821    )?;
5822    writeln!(
5823        code,
5824        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5825    )?;
5826    writeln!(
5827        code,
5828        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5829    )?;
5830    writeln!(
5831        code,
5832        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5833    )?;
5834    writeln!(
5835        code,
5836        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5837    )?;
5838    writeln!(
5839        code,
5840        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5841    )?;
5842    writeln!(
5843        code,
5844        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5845    )?;
5846    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5847    writeln!(
5848        code,
5849        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5850    )?;
5851    writeln!(
5852        code,
5853        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5854    )?;
5855    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5856    writeln!(code, "        }}")?;
5857    writeln!(code, "    }}")?;
5858    writeln!(code)?;
5859
5860    // dispatch_silu_mul_fused_batch
5861    writeln!(
5862        code,
5863        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5864    )?;
5865    writeln!(
5866        code,
5867        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5868    )?;
5869    writeln!(code, "        let count: u32 = n as u32;")?;
5870    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5871    writeln!(
5872        code,
5873        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5874    )?;
5875    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5876    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5877    writeln!(
5878        code,
5879        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5880    )?;
5881    writeln!(
5882        code,
5883        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5884    )?;
5885    writeln!(code, "        let total = num_tokens * n;")?;
5886    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5887    writeln!(
5888        code,
5889        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5890    )?;
5891    writeln!(
5892        code,
5893        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5894    )?;
5895    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5896    writeln!(code, "    }}")?;
5897    writeln!(code)?;
5898
5899    // dispatch_add_inplace_batch_n (add n elements in-place)
5900    writeln!(
5901        code,
5902        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5903    )?;
5904    writeln!(
5905        code,
5906        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5907    )?;
5908    writeln!(code, "        let count: u32 = total_n as u32;")?;
5909    writeln!(
5910        code,
5911        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5912    )?;
5913    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5914    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5915    writeln!(
5916        code,
5917        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5918    )?;
5919    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5920    writeln!(
5921        code,
5922        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5923    )?;
5924    writeln!(
5925        code,
5926        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5927    )?;
5928    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5929    writeln!(code, "    }}")?;
5930    writeln!(code)?;
5931
5932    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5933    writeln!(
5934        code,
5935        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5936    )?;
5937    writeln!(
5938        code,
5939        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5940    )?;
5941    writeln!(code, "        let n: u32 = count as u32;")?;
5942    writeln!(
5943        code,
5944        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5945    )?;
5946    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5947    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5948    writeln!(
5949        code,
5950        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5951    )?;
5952    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5953    writeln!(
5954        code,
5955        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5956    )?;
5957    writeln!(
5958        code,
5959        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5960    )?;
5961    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5962    writeln!(code, "    }}")?;
5963    writeln!(code)?;
5964
5965    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
5966    writeln!(
5967        code,
5968        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5969    )?;
5970    writeln!(
5971        code,
5972        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5973    )?;
5974    writeln!(code, "        let n: u32 = count as u32;")?;
5975    writeln!(
5976        code,
5977        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5978    )?;
5979    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5980    writeln!(
5981        code,
5982        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5983    )?;
5984    writeln!(
5985        code,
5986        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5987    )?;
5988    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5989    writeln!(
5990        code,
5991        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5992    )?;
5993    writeln!(
5994        code,
5995        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5996    )?;
5997    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5998    writeln!(code, "    }}")?;
5999    writeln!(code)?;
6000
6001    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
6002    writeln!(
6003        code,
6004        "    /// Copy from src at byte offset to dst at float offset."
6005    )?;
6006    writeln!(
6007        code,
6008        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6009    )?;
6010    writeln!(code, "        let n: u32 = count as u32;")?;
6011    writeln!(
6012        code,
6013        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
6014    )?;
6015    writeln!(
6016        code,
6017        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6018    )?;
6019    writeln!(
6020        code,
6021        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6022    )?;
6023    writeln!(
6024        code,
6025        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6026    )?;
6027    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6028    writeln!(
6029        code,
6030        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6031    )?;
6032    writeln!(
6033        code,
6034        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6035    )?;
6036    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6037    writeln!(code, "    }}")?;
6038    writeln!(code)?;
6039
6040    // dispatch_copy_kv_batch
6041    writeln!(
6042        code,
6043        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6044    )?;
6045    writeln!(
6046        code,
6047        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6048    )?;
6049    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6050    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6051    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6052    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6053    writeln!(code, "        let so: u32 = src_offset as u32;")?;
6054    writeln!(
6055        code,
6056        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6057    )?;
6058    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6059    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
6060    writeln!(
6061        code,
6062        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6063    )?;
6064    writeln!(
6065        code,
6066        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6067    )?;
6068    writeln!(
6069        code,
6070        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6071    )?;
6072    writeln!(
6073        code,
6074        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6075    )?;
6076    writeln!(
6077        code,
6078        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6079    )?;
6080    writeln!(code, "        let total = num_tokens * kv_dim;")?;
6081    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6082    writeln!(
6083        code,
6084        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6085    )?;
6086    writeln!(
6087        code,
6088        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6089    )?;
6090    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6091    writeln!(code, "    }}")?;
6092    writeln!(code)?;
6093
6094    // dispatch_attention_batch
6095    writeln!(
6096        code,
6097        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
6098    )?;
6099    writeln!(
6100        code,
6101        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6102    )?;
6103    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6104    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6105    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
6106    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6107    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6108    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
6109    // Attention kernel selection:
6110    //   * Legacy `attention_batch` materializes scores[4096] in threadgroup memory
6111    //     and uses scalar simdgroup reductions.  Fast at short seq_len, no MMA.
6112    //   * `attention_flash_batch` streams K/V with online softmax; no seq cap,
6113    //     scalar math, ~7-14 % slower than legacy at long contexts (no MMA).
6114    //   * `attention_mma_flash_batch` adds hardware simdgroup_matrix<half, 8, 8>
6115    //     MMA for both Q·K^T and P·V, processing Q_BLOCK=8 tokens per TG.
6116    //     Gated behind FORGE_MMA_ATTN=1 until verified on all supported models.
6117    //     Requires HEAD_DIM ≤ 128 and num_tokens ≥ 8.
6118    writeln!(code, "        let max_seq = base_pos + num_tokens;")?;
6119    writeln!(code, "        let _ = max_seq;")?;
6120    writeln!(
6121        code,
6122        "        let use_mma_flash = std::env::var(\"FORGE_MMA_ATTN\")"
6123    )?;
6124    writeln!(
6125        code,
6126        "            .map(|v| v == \"1\").unwrap_or(false) && HEAD_DIM <= 128 && num_tokens >= 8;"
6127    )?;
6128    writeln!(code, "        if use_mma_flash {{")?;
6129    writeln!(
6130        code,
6131        "            let pipe = &self.attention_mma_flash_batch_pipeline;"
6132    )?;
6133    writeln!(code, "            enc.set_compute_pipeline_state(pipe);")?;
6134    writeln!(code, "            enc.set_buffer(0, Some(q_buf), 0);")?;
6135    writeln!(code, "            enc.set_buffer(1, Some(k_cache), 0);")?;
6136    writeln!(code, "            enc.set_buffer(2, Some(v_cache), 0);")?;
6137    writeln!(code, "            enc.set_buffer(3, Some(output), 0);")?;
6138    writeln!(
6139        code,
6140        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6141    )?;
6142    writeln!(
6143        code,
6144        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6145    )?;
6146    writeln!(
6147        code,
6148        "            enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6149    )?;
6150    writeln!(
6151        code,
6152        "            enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6153    )?;
6154    writeln!(
6155        code,
6156        "            enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6157    )?;
6158    writeln!(
6159        code,
6160        "            enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6161    )?;
6162    writeln!(
6163        code,
6164        "            // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6165    )?;
6166    writeln!(code, "            let tg_size = MTLSize::new(128, 1, 1);")?;
6167    writeln!(
6168        code,
6169        "            let q_blocks = ((num_tokens + 7) / 8) as u64;"
6170    )?;
6171    writeln!(
6172        code,
6173        "            let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6174    )?;
6175    writeln!(
6176        code,
6177        "            enc.dispatch_thread_groups(grid_size, tg_size);"
6178    )?;
6179    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6180    writeln!(code, "            return;")?;
6181    writeln!(code, "        }}")?;
6182    writeln!(code, "        let pipe = &self.attention_batch_pipeline;")?;
6183    writeln!(code, "        enc.set_compute_pipeline_state(pipe);")?;
6184    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
6185    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
6186    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
6187    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
6188    writeln!(
6189        code,
6190        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6191    )?;
6192    writeln!(
6193        code,
6194        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6195    )?;
6196    writeln!(
6197        code,
6198        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6199    )?;
6200    writeln!(
6201        code,
6202        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6203    )?;
6204    writeln!(
6205        code,
6206        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6207    )?;
6208    writeln!(
6209        code,
6210        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6211    )?;
6212    writeln!(
6213        code,
6214        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6215    )?;
6216    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6217    writeln!(
6218        code,
6219        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6220    )?;
6221    writeln!(
6222        code,
6223        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6224    )?;
6225    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6226    writeln!(code, "    }}")?;
6227    writeln!(code)?;
6228
6229    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
6230    writeln!(
6231        code,
6232        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6233    )?;
6234    writeln!(
6235        code,
6236        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6237    )?;
6238    writeln!(
6239        code,
6240        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6241    )?;
6242    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6243    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6244    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
6245    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
6246    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
6247    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
6248    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
6249    writeln!(
6250        code,
6251        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6252    )?;
6253    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
6254    writeln!(
6255        code,
6256        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6257    )?;
6258    writeln!(
6259        code,
6260        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6261    )?;
6262    writeln!(
6263        code,
6264        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6265    )?;
6266    writeln!(
6267        code,
6268        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6269    )?;
6270    writeln!(
6271        code,
6272        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6273    )?;
6274    writeln!(
6275        code,
6276        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6277    )?;
6278    writeln!(
6279        code,
6280        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6281    )?;
6282    writeln!(
6283        code,
6284        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6285    )?;
6286    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6287    writeln!(
6288        code,
6289        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6290    )?;
6291    writeln!(
6292        code,
6293        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6294    )?;
6295    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6296    writeln!(code, "    }}")?;
6297    writeln!(code)?;
6298
6299    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
6300    writeln!(
6301        code,
6302        "    /// Dispatch fused K+V cache copy in one kernel launch."
6303    )?;
6304    writeln!(
6305        code,
6306        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6307    )?;
6308    writeln!(
6309        code,
6310        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6311    )?;
6312    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
6313    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
6314    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
6315    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
6316    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
6317    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
6318    writeln!(
6319        code,
6320        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6321    )?;
6322    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
6323    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
6324    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
6325    writeln!(
6326        code,
6327        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6328    )?;
6329    writeln!(
6330        code,
6331        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6332    )?;
6333    writeln!(
6334        code,
6335        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6336    )?;
6337    writeln!(
6338        code,
6339        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6340    )?;
6341    writeln!(
6342        code,
6343        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6344    )?;
6345    writeln!(
6346        code,
6347        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6348    )?;
6349    writeln!(
6350        code,
6351        "        let total = num_tokens * kv_dim * 2;  // K + V"
6352    )?;
6353    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
6354    writeln!(
6355        code,
6356        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6357    )?;
6358    writeln!(
6359        code,
6360        "        enc.dispatch_thread_groups(grid_size, tg_size);"
6361    )?;
6362    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6363    writeln!(code, "    }}")?;
6364
6365    writeln!(code, "}}")?;
6366    writeln!(code)?;
6367
6368    Ok(())
6369}
6370
6371fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6372    writeln!(
6373        code,
6374        "// ── Helper functions ──────────────────────────────────"
6375    )?;
6376    writeln!(code)?;
6377    writeln!(
6378        code,
6379        "/// Create a compute pipeline from a named function in the Metal library."
6380    )?;
6381    writeln!(
6382        code,
6383        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6384    )?;
6385    writeln!(
6386        code,
6387        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6388    )?;
6389    writeln!(
6390        code,
6391        "    device.new_compute_pipeline_state_with_function(&func)"
6392    )?;
6393    writeln!(
6394        code,
6395        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6396    )?;
6397    writeln!(code, "}}")?;
6398    writeln!(code)?;
6399
6400    Ok(())
6401}
6402
6403// ---------------------------------------------------------------------------
6404// main.rs generation
6405// ---------------------------------------------------------------------------
6406
6407fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6408    let _sanitized = sanitize_name(model_name);
6409    let _vocab = config.vocab_size;
6410
6411    let mut code = String::with_capacity(16 * 1024);
6412    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6413    writeln!(
6414        code,
6415        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6416    )?;
6417    writeln!(code)?;
6418    writeln!(code, "mod model;")?;
6419    writeln!(code)?;
6420    writeln!(code, "use std::io::Write;")?;
6421    writeln!(code, "use std::time::Instant;")?;
6422    writeln!(code, "use serde::Deserialize;")?;
6423    writeln!(code)?;
6424
6425    // -- main function --
6426    writeln!(code, "fn main() {{")?;
6427    writeln!(
6428        code,
6429        "    let args: Vec<String> = std::env::args().collect();"
6430    )?;
6431    writeln!(code)?;
6432    writeln!(
6433        code,
6434        "    // Detect --serve mode (only requires weights + tokenizer)"
6435    )?;
6436    writeln!(
6437        code,
6438        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
6439    )?;
6440    writeln!(code)?;
6441    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
6442    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6443    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6444    writeln!(code, "        std::process::exit(1);")?;
6445    writeln!(code, "    }}")?;
6446    writeln!(code)?;
6447    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
6448    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6449    writeln!(code, "        std::process::exit(1);")?;
6450    writeln!(code, "    }}")?;
6451    writeln!(code)?;
6452    writeln!(code, "    let weights_path = &args[1];")?;
6453    writeln!(code, "    let tokenizer_path = &args[2];")?;
6454    writeln!(code)?;
6455    writeln!(code, "    // Parse optional flags")?;
6456    writeln!(code, "    let mut max_tokens: usize = 128;")?;
6457    writeln!(code, "    let mut port: u16 = 8080;")?;
6458    writeln!(
6459        code,
6460        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6461    )?;
6462    writeln!(
6463        code,
6464        "    let profile = args.iter().any(|a| a == \"--profile\");"
6465    )?;
6466    writeln!(code, "    let mut i = 3;")?;
6467    writeln!(code, "    while i < args.len() {{")?;
6468    writeln!(
6469        code,
6470        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6471    )?;
6472    writeln!(
6473        code,
6474        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
6475    )?;
6476    writeln!(code, "            i += 2;")?;
6477    writeln!(
6478        code,
6479        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6480    )?;
6481    writeln!(
6482        code,
6483        "            port = args[i + 1].parse().unwrap_or(8080);"
6484    )?;
6485    writeln!(code, "            i += 2;")?;
6486    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
6487    writeln!(code, "            i += 1;")?;
6488    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
6489    writeln!(code, "            i += 1;")?;
6490    writeln!(code, "        }} else {{")?;
6491    writeln!(code, "            i += 1;")?;
6492    writeln!(code, "        }}")?;
6493    writeln!(code, "    }}")?;
6494    writeln!(code)?;
6495
6496    // -- load model (shared by both modes) --
6497    writeln!(
6498        code,
6499        "    // Memory-map weights for zero-copy loading on Apple Silicon"
6500    )?;
6501    writeln!(
6502        code,
6503        "    let weights_file = std::fs::File::open(weights_path)"
6504    )?;
6505    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6506    writeln!(
6507        code,
6508        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6509    )?;
6510    writeln!(code)?;
6511    writeln!(code, "    // Load tokenizer")?;
6512    writeln!(
6513        code,
6514        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6515    )?;
6516    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6517    writeln!(code)?;
6518    writeln!(code, "    // Create Metal model")?;
6519    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6520    writeln!(
6521        code,
6522        "    let mut model = model::MetalModel::new(&weights_mmap);"
6523    )?;
6524    writeln!(code)?;
6525
6526    // -- branch: serve vs CLI --
6527    writeln!(code, "    if serve_mode {{")?;
6528    writeln!(code, "        serve(model, tokenizer, port);")?;
6529    writeln!(code, "    }} else {{")?;
6530    writeln!(code, "        let prompt = &args[3];")?;
6531    writeln!(
6532        code,
6533        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6534    )?;
6535    writeln!(code, "    }}")?;
6536    writeln!(code, "}}")?;
6537    writeln!(code)?;
6538
6539    // -- cli_mode function --
6540    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6541    writeln!(code, "    // Tokenize prompt")?;
6542    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6543    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6544    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6545    writeln!(code)?;
6546    writeln!(
6547        code,
6548        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6549    )?;
6550    writeln!(
6551        code,
6552        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6553    )?;
6554    writeln!(
6555        code,
6556        "    // The last token uses synchronous forward() to get logits."
6557    )?;
6558    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6559    writeln!(code, "    let prefill_start = Instant::now();")?;
6560    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6561    writeln!(
6562        code,
6563        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6564    )?;
6565    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6566    writeln!(code, "    }} else {{")?;
6567    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6568    writeln!(code, "    }};")?;
6569    writeln!(
6570        code,
6571        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6572    )?;
6573    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6574    writeln!(
6575        code,
6576        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6577    )?;
6578    writeln!(
6579        code,
6580        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6581    )?;
6582    writeln!(code)?;
6583    writeln!(code, "    // Generate tokens")?;
6584    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6585    writeln!(code, "    let gen_start = Instant::now();")?;
6586    writeln!(code, "    let mut generated_count: usize = 0;")?;
6587    writeln!(code)?;
6588    writeln!(code, "    for _ in 0..max_tokens {{")?;
6589    writeln!(
6590        code,
6591        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6592    )?;
6593    writeln!(code, "            if !quiet {{")?;
6594    writeln!(code, "                print!(\"{{}}\", text);")?;
6595    writeln!(code, "                std::io::stdout().flush().ok();")?;
6596    writeln!(code, "            }}")?;
6597    writeln!(code, "        }}")?;
6598    writeln!(code, "        generated_count += 1;")?;
6599    writeln!(code)?;
6600    writeln!(
6601        code,
6602        "        // Use profiling forward for first token when --profile is set"
6603    )?;
6604    writeln!(
6605        code,
6606        "        let logits = if profile && generated_count == 1 {{"
6607    )?;
6608    writeln!(code, "            model.forward_profile(next_token)")?;
6609    writeln!(code, "        }} else {{")?;
6610    writeln!(code, "            model.forward(next_token)")?;
6611    writeln!(code, "        }};")?;
6612    writeln!(code, "        next_token = argmax(&logits);")?;
6613    writeln!(code)?;
6614    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6615    writeln!(code, "        if next_token == 2 {{")?;
6616    writeln!(code, "            break;")?;
6617    writeln!(code, "        }}")?;
6618    writeln!(code)?;
6619    writeln!(
6620        code,
6621        "        // Yield between tokens to reduce sustained GPU thermal load."
6622    )?;
6623    writeln!(
6624        code,
6625        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6626    )?;
6627    writeln!(
6628        code,
6629        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6630    )?;
6631    writeln!(
6632        code,
6633        "        // briefly, providing a micro-break that helps sustain peak throughput."
6634    )?;
6635    writeln!(code, "        std::thread::yield_now();")?;
6636    writeln!(code, "    }}")?;
6637    writeln!(code, "    if !quiet {{")?;
6638    writeln!(code, "        println!();")?;
6639    writeln!(code, "    }}")?;
6640    writeln!(
6641        code,
6642        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6643    )?;
6644    writeln!(
6645        code,
6646        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6647    )?;
6648    writeln!(
6649        code,
6650        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6651    )?;
6652    writeln!(code, "}}")?;
6653    writeln!(code)?;
6654
6655    // -- argmax helper --
6656    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6657    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6658    writeln!(code, "    logits.iter()")?;
6659    writeln!(code, "        .enumerate()")?;
6660    writeln!(
6661        code,
6662        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6663    )?;
6664    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6665    writeln!(code, "        .unwrap_or(0)")?;
6666    writeln!(code, "}}")?;
6667    writeln!(code)?;
6668
6669    // -- Request/Response types for OpenAI API --
6670    writeln!(
6671        code,
6672        "// -----------------------------------------------------------------------"
6673    )?;
6674    writeln!(code, "// OpenAI-compatible API server")?;
6675    writeln!(
6676        code,
6677        "// -----------------------------------------------------------------------"
6678    )?;
6679    writeln!(code)?;
6680    writeln!(code, "#[derive(Deserialize)]")?;
6681    writeln!(code, "struct ChatRequest {{")?;
6682    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6683    writeln!(code, "    #[serde(default)]")?;
6684    writeln!(code, "    stream: Option<bool>,")?;
6685    writeln!(code, "    #[serde(default)]")?;
6686    writeln!(code, "    max_tokens: Option<usize>,")?;
6687    writeln!(code, "    #[serde(default)]")?;
6688    writeln!(code, "    temperature: Option<f32>,")?;
6689    writeln!(code, "    #[serde(default)]")?;
6690    writeln!(code, "    model: Option<String>,")?;
6691    writeln!(code, "}}")?;
6692    writeln!(code)?;
6693    writeln!(code, "#[derive(Deserialize)]")?;
6694    writeln!(code, "struct ChatMessage {{")?;
6695    writeln!(code, "    role: String,")?;
6696    writeln!(code, "    content: String,")?;
6697    writeln!(code, "}}")?;
6698    writeln!(code)?;
6699
6700    // -- format_chat_messages --
6701    writeln!(
6702        code,
6703        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6704    )?;
6705    writeln!(code, "    let mut prompt = String::new();")?;
6706    writeln!(code, "    for msg in messages {{")?;
6707    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6708    writeln!(code, "    }}")?;
6709    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6710    writeln!(code, "    prompt")?;
6711    writeln!(code, "}}")?;
6712    writeln!(code)?;
6713
6714    // -- prefill helper --
6715    writeln!(
6716        code,
6717        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6718    )?;
6719    writeln!(code, "    let len = tokens.len();")?;
6720    writeln!(code, "    if len > 1 {{")?;
6721    writeln!(
6722        code,
6723        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6724    )?;
6725    writeln!(code, "    }}")?;
6726    writeln!(code, "    model.forward(tokens[len - 1])")?;
6727    writeln!(code, "}}")?;
6728    writeln!(code)?;
6729
6730    // -- serve function --
6731    writeln!(
6732        code,
6733        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6734    )?;
6735    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6736    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6737    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6738    writeln!(
6739        code,
6740        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6741    )?;
6742    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6743    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6744    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6745    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6746    writeln!(code)?;
6747    writeln!(code, "    for request in server.incoming_requests() {{")?;
6748    writeln!(code, "        let method = request.method().to_string();")?;
6749    writeln!(code, "        let url = request.url().to_string();")?;
6750    writeln!(code)?;
6751    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6752
6753    // -- POST /v1/chat/completions --
6754    writeln!(
6755        code,
6756        "            (\"POST\", \"/v1/chat/completions\") => {{"
6757    )?;
6758    writeln!(
6759        code,
6760        "                handle_chat_completion(&mut model, &tokenizer, request);"
6761    )?;
6762    writeln!(code, "            }}")?;
6763
6764    // -- GET /v1/models --
6765    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6766    writeln!(code, "                let body = serde_json::json!({{")?;
6767    writeln!(code, "                    \"object\": \"list\",")?;
6768    writeln!(code, "                    \"data\": [{{")?;
6769    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6770    writeln!(code, "                        \"object\": \"model\",")?;
6771    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6772    writeln!(code, "                    }}]")?;
6773    writeln!(code, "                }});")?;
6774    writeln!(
6775        code,
6776        "                let resp = tiny_http::Response::from_string(body.to_string())"
6777    )?;
6778    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6779    writeln!(code, "                request.respond(resp).ok();")?;
6780    writeln!(code, "            }}")?;
6781
6782    // -- GET /health --
6783    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6784    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6785    writeln!(code, "                request.respond(resp).ok();")?;
6786    writeln!(code, "            }}")?;
6787
6788    // -- 404 --
6789    writeln!(code, "            _ => {{")?;
6790    writeln!(
6791        code,
6792        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6793    )?;
6794    writeln!(code, "                    .with_status_code(404);")?;
6795    writeln!(code, "                request.respond(resp).ok();")?;
6796    writeln!(code, "            }}")?;
6797    writeln!(code, "        }}")?;
6798    writeln!(code, "    }}")?;
6799    writeln!(code, "}}")?;
6800    writeln!(code)?;
6801
6802    // -- handle_chat_completion --
6803    writeln!(code, "fn handle_chat_completion(")?;
6804    writeln!(code, "    model: &mut model::MetalModel,")?;
6805    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6806    writeln!(code, "    mut request: tiny_http::Request,")?;
6807    writeln!(code, ") {{")?;
6808    writeln!(code, "    // Read request body")?;
6809    writeln!(code, "    let mut body = String::new();")?;
6810    writeln!(
6811        code,
6812        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6813    )?;
6814    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6815    writeln!(code, "            .with_status_code(400);")?;
6816    writeln!(code, "        request.respond(resp).ok();")?;
6817    writeln!(code, "        return;")?;
6818    writeln!(code, "    }}")?;
6819    writeln!(code)?;
6820    writeln!(code, "    // Parse JSON")?;
6821    writeln!(
6822        code,
6823        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6824    )?;
6825    writeln!(code, "        Ok(r) => r,")?;
6826    writeln!(code, "        Err(e) => {{")?;
6827    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6828    writeln!(code, "                .with_status_code(400);")?;
6829    writeln!(code, "            request.respond(resp).ok();")?;
6830    writeln!(code, "            return;")?;
6831    writeln!(code, "        }}")?;
6832    writeln!(code, "    }};")?;
6833    writeln!(code)?;
6834    writeln!(
6835        code,
6836        "    let prompt = format_chat_messages(&req.messages);"
6837    )?;
6838    writeln!(
6839        code,
6840        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6841    )?;
6842    writeln!(code, "        Ok(e) => e,")?;
6843    writeln!(code, "        Err(e) => {{")?;
6844    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6845    writeln!(code, "                .with_status_code(500);")?;
6846    writeln!(code, "            request.respond(resp).ok();")?;
6847    writeln!(code, "            return;")?;
6848    writeln!(code, "        }}")?;
6849    writeln!(code, "    }};")?;
6850    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6851    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6852    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6853    writeln!(
6854        code,
6855        "    let _temperature = req.temperature.unwrap_or(1.0);"
6856    )?;
6857    writeln!(code)?;
6858
6859    // -- Reset KV cache for each request --
6860    writeln!(code, "    model.reset();")?;
6861    writeln!(code)?;
6862
6863    // -- Prefill with timing --
6864    writeln!(code, "    let prefill_start = Instant::now();")?;
6865    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6866    writeln!(
6867        code,
6868        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6869    )?;
6870    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6871    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6872    writeln!(code)?;
6873
6874    writeln!(code, "    if stream {{")?;
6875
6876    // -- SSE streaming response --
6877    writeln!(
6878        code,
6879        "        // SSE streaming: generate tokens and build SSE body"
6880    )?;
6881    writeln!(code, "        let gen_start = Instant::now();")?;
6882    writeln!(code, "        let mut generated_count: usize = 0;")?;
6883    writeln!(code, "        let mut sse_body = String::new();")?;
6884    writeln!(code, "        for _ in 0..max_tokens {{")?;
6885    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6886    writeln!(
6887        code,
6888        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6889    )?;
6890    writeln!(
6891        code,
6892        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6893    )?;
6894    writeln!(
6895        code,
6896        "                // escaped includes surrounding quotes, strip them"
6897    )?;
6898    writeln!(
6899        code,
6900        "                let inner = &escaped[1..escaped.len()-1];"
6901    )?;
6902    writeln!(code, "                sse_body.push_str(&format!(")?;
6903    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6904    writeln!(code, "                    inner")?;
6905    writeln!(code, "                ));")?;
6906    writeln!(code, "            }}")?;
6907    writeln!(code, "            generated_count += 1;")?;
6908    writeln!(code, "            let logits = model.forward(next_token);")?;
6909    writeln!(code, "            next_token = argmax(&logits);")?;
6910    writeln!(code, "        }}")?;
6911    writeln!(
6912        code,
6913        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6914    )?;
6915    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6916    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6917    writeln!(code)?;
6918    writeln!(
6919        code,
6920        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6921    )?;
6922    writeln!(code, "        sse_body.push_str(&format!(")?;
6923    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6924    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6925    writeln!(code, "        ));")?;
6926    writeln!(code)?;
6927    writeln!(
6928        code,
6929        "        let resp = tiny_http::Response::from_string(sse_body)"
6930    )?;
6931    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6932    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6933    writeln!(code, "        request.respond(resp).ok();")?;
6934
6935    writeln!(code, "    }} else {{")?;
6936
6937    // -- Non-streaming response --
6938    writeln!(
6939        code,
6940        "        // Non-streaming: generate all tokens, return JSON"
6941    )?;
6942    writeln!(code, "        let gen_start = Instant::now();")?;
6943    writeln!(code, "        let mut generated_count: usize = 0;")?;
6944    writeln!(code, "        let mut generated = String::new();")?;
6945    writeln!(code, "        for _ in 0..max_tokens {{")?;
6946    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6947    writeln!(
6948        code,
6949        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6950    )?;
6951    writeln!(code, "                generated.push_str(&text);")?;
6952    writeln!(code, "            }}")?;
6953    writeln!(code, "            generated_count += 1;")?;
6954    writeln!(code, "            let logits = model.forward(next_token);")?;
6955    writeln!(code, "            next_token = argmax(&logits);")?;
6956    writeln!(code, "        }}")?;
6957    writeln!(
6958        code,
6959        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6960    )?;
6961    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6962    writeln!(code)?;
6963    writeln!(code, "        let resp_json = serde_json::json!({{")?;
6964    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
6965    writeln!(code, "            \"object\": \"chat.completion\",")?;
6966    writeln!(code, "            \"choices\": [{{")?;
6967    writeln!(code, "                \"index\": 0,")?;
6968    writeln!(code, "                \"message\": {{")?;
6969    writeln!(code, "                    \"role\": \"assistant\",")?;
6970    writeln!(code, "                    \"content\": generated")?;
6971    writeln!(code, "                }},")?;
6972    writeln!(code, "                \"finish_reason\": \"stop\"")?;
6973    writeln!(code, "            }}],")?;
6974    writeln!(code, "            \"usage\": {{")?;
6975    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
6976    writeln!(
6977        code,
6978        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6979    )?;
6980    writeln!(
6981        code,
6982        "                \"generation_tokens\": generated_count,"
6983    )?;
6984    writeln!(
6985        code,
6986        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6987    )?;
6988    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
6989    writeln!(code, "            }}")?;
6990    writeln!(code, "        }});")?;
6991    writeln!(
6992        code,
6993        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
6994    )?;
6995    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6996    writeln!(code, "        request.respond(resp).ok();")?;
6997    writeln!(code, "    }}")?;
6998    writeln!(code, "}}")?;
6999
7000    Ok(code)
7001}
7002
7003// ---------------------------------------------------------------------------
7004// Tests
7005// ---------------------------------------------------------------------------
7006
7007#[cfg(test)]
7008mod tests {
7009    use super::*;
7010    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7011
7012    fn minimal_config() -> ModelConfig {
7013        ModelConfig {
7014            architecture: Architecture::Llama,
7015            hidden_size: 64,
7016            intermediate_size: 128,
7017            num_layers: 2,
7018            num_attention_heads: 4,
7019            num_kv_heads: 4,
7020            head_dim: 16,
7021            vocab_size: 256,
7022            max_seq_len: 512,
7023            rms_norm_eps: 1e-5,
7024            rope_theta: 10000.0,
7025            dtype: DType::F32,
7026            sliding_window_size: None,
7027            qkv_bias: false,
7028        }
7029    }
7030
7031    fn minimal_graph() -> Graph {
7032        Graph::new("test-metal").with_config(minimal_config())
7033    }
7034
7035    #[test]
7036    fn generate_metal_project_creates_files() {
7037        let dir = tempfile::tempdir().unwrap();
7038        let graph = minimal_graph();
7039        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7040
7041        assert!(
7042            dir.path().join("Cargo.toml").exists(),
7043            "Cargo.toml should be created"
7044        );
7045        assert!(
7046            dir.path().join("src/model.rs").exists(),
7047            "src/model.rs should be created"
7048        );
7049        assert!(
7050            dir.path().join("src/main.rs").exists(),
7051            "src/main.rs should be created"
7052        );
7053        assert!(
7054            dir.path().join("shaders/kernels.metal").exists(),
7055            "shaders/kernels.metal should be created"
7056        );
7057    }
7058
7059    #[test]
7060    fn generated_cargo_toml_has_metal_dep() {
7061        let toml = generate_cargo_toml("my-model");
7062        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7063        assert!(
7064            toml.contains("tokenizers"),
7065            "Cargo.toml should depend on tokenizers"
7066        );
7067        assert!(
7068            toml.contains("memmap2"),
7069            "Cargo.toml should depend on memmap2"
7070        );
7071        assert!(toml.contains("half"), "Cargo.toml should depend on half");
7072    }
7073
7074    #[test]
7075    fn generated_model_rs_contains_metal_code() {
7076        let config = minimal_config();
7077        let model_rs = generate_model_rs(&config).unwrap();
7078
7079        assert!(
7080            model_rs.contains("pub struct MetalModel"),
7081            "model.rs should define MetalModel struct"
7082        );
7083        assert!(
7084            model_rs.contains("matmul_pipeline: ComputePipelineState"),
7085            "MetalModel should have matmul_pipeline field"
7086        );
7087        assert!(
7088            model_rs.contains("Device::system_default()"),
7089            "model.rs should use Metal device"
7090        );
7091        assert!(
7092            model_rs.contains("new_library_with_source"),
7093            "model.rs should compile Metal shaders"
7094        );
7095        assert!(
7096            model_rs.contains("fn new(weights: &[u8])"),
7097            "MetalModel should implement new()"
7098        );
7099        assert!(
7100            model_rs.contains("fn forward(&mut self, token_id: u32)"),
7101            "MetalModel should implement forward()"
7102        );
7103    }
7104
7105    #[test]
7106    fn generated_shaders_contain_kernel_names() {
7107        let shaders = generate_metal_shaders(&minimal_config());
7108
7109        assert!(
7110            shaders.contains("kernel void matmul_vec"),
7111            "shaders should contain matmul_vec kernel"
7112        );
7113        assert!(
7114            shaders.contains("kernel void rms_norm"),
7115            "shaders should contain rms_norm kernel"
7116        );
7117        assert!(
7118            shaders.contains("kernel void rope"),
7119            "shaders should contain rope kernel"
7120        );
7121        assert!(
7122            shaders.contains("kernel void softmax"),
7123            "shaders should contain softmax kernel"
7124        );
7125        assert!(
7126            shaders.contains("kernel void silu_mul("),
7127            "shaders should contain silu_mul kernel"
7128        );
7129        assert!(
7130            shaders.contains("kernel void silu_mul_fused"),
7131            "shaders should contain silu_mul_fused kernel"
7132        );
7133        assert!(
7134            shaders.contains("kernel void elementwise_add"),
7135            "shaders should contain elementwise_add kernel"
7136        );
7137        assert!(
7138            shaders.contains("kernel void attention"),
7139            "shaders should contain attention kernel"
7140        );
7141        assert!(
7142            shaders.contains("kernel void add_inplace"),
7143            "shaders should contain add_inplace kernel"
7144        );
7145        assert!(
7146            shaders.contains("kernel void copy_buffer"),
7147            "shaders should contain copy_buffer kernel"
7148        );
7149        assert!(
7150            shaders.contains("kernel void copy_offset"),
7151            "shaders should contain copy_offset kernel"
7152        );
7153    }
7154
7155    #[test]
7156    fn generated_shaders_use_simdgroup_features() {
7157        let shaders = generate_metal_shaders(&minimal_config());
7158
7159        assert!(
7160            shaders.contains("threadgroup_barrier"),
7161            "shaders should use threadgroup barriers"
7162        );
7163        assert!(
7164            shaders.contains("threadgroup float"),
7165            "shaders should use threadgroup shared memory"
7166        );
7167        assert!(
7168            shaders.contains("thread_index_in_threadgroup"),
7169            "shaders should use threadgroup indexing"
7170        );
7171        assert!(
7172            shaders.contains("simd_sum"),
7173            "shaders should use simd_sum for warp-level reduction"
7174        );
7175        assert!(
7176            shaders.contains("simd_max"),
7177            "attention kernel should use simd_max for cooperative softmax"
7178        );
7179        assert!(
7180            shaders.contains("thread_index_in_simdgroup"),
7181            "shaders should use simdgroup lane indexing"
7182        );
7183        assert!(
7184            shaders.contains("simdgroup_index_in_threadgroup"),
7185            "shaders should use simdgroup indexing within threadgroup"
7186        );
7187        assert!(
7188            shaders.contains("float4"),
7189            "matmul_vec should use float4 vectorized loads"
7190        );
7191    }
7192
7193    #[test]
7194    fn generated_main_rs_has_tokenizer_usage() {
7195        let config = minimal_config();
7196        let main_rs = generate_main_rs("test-model", &config).unwrap();
7197
7198        assert!(
7199            main_rs.contains("tokenizers::Tokenizer"),
7200            "main.rs should use tokenizers crate"
7201        );
7202        assert!(
7203            main_rs.contains("MetalModel::new"),
7204            "main.rs should call MetalModel::new"
7205        );
7206        assert!(
7207            main_rs.contains("model.forward"),
7208            "main.rs should call model.forward"
7209        );
7210        assert!(
7211            main_rs.contains("memmap2"),
7212            "main.rs should use memmap2 for zero-copy weight loading"
7213        );
7214    }
7215
7216    #[test]
7217    fn missing_config_returns_error() {
7218        let dir = tempfile::tempdir().unwrap();
7219        let graph = Graph::new("no-config");
7220        let result = generate_metal_project(&graph, dir.path(), "fail");
7221        assert!(
7222            matches!(result, Err(MetalCodegenError::MissingConfig)),
7223            "should fail with MissingConfig when graph has no config"
7224        );
7225    }
7226
7227    #[test]
7228    fn sanitize_name_works() {
7229        assert_eq!(sanitize_name("My Model!"), "my-model");
7230        assert_eq!(sanitize_name("test_model"), "test-model");
7231        assert_eq!(sanitize_name("simple"), "simple");
7232    }
7233
7234    #[test]
7235    fn generated_forward_uses_single_command_buffer() {
7236        let config = minimal_config();
7237        let model_rs = generate_model_rs(&config).unwrap();
7238
7239        // The forward function should create exactly one command buffer.
7240        // Use the exact signature to avoid matching forward_prefill/forward_profile.
7241        let forward_start = model_rs
7242            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7243            .unwrap();
7244        let forward_body = &model_rs[forward_start..];
7245        // End at the next pub/private method
7246        let forward_end = forward_body
7247            .find("\n    pub fn forward_profile")
7248            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7249            .or_else(|| forward_body.find("\n    fn dispatch_"))
7250            .unwrap_or(forward_body.len());
7251        let forward_code = &forward_body[..forward_end];
7252
7253        // Should have exactly one new_command_buffer call
7254        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7255        assert_eq!(
7256            cmd_buf_count, 1,
7257            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7258        );
7259
7260        // Should have exactly one commit call
7261        let commit_count = forward_code.matches("cmd.commit()").count();
7262        assert_eq!(
7263            commit_count, 1,
7264            "forward() should commit exactly once, found {commit_count}"
7265        );
7266
7267        // Should wait: once for cmd + possibly once for prev_cmd drain
7268        let wait_count = forward_code.matches("wait_until_completed()").count();
7269        assert!(
7270            wait_count >= 1 && wait_count <= 2,
7271            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7272        );
7273    }
7274
7275    #[test]
7276    fn generated_model_has_preallocated_working_buffers() {
7277        let config = minimal_config();
7278        let model_rs = generate_model_rs(&config).unwrap();
7279
7280        for buf_name in &[
7281            "normed_buf",
7282            "qkv_buf",
7283            "attn_out_buf",
7284            "attn_proj_buf",
7285            "gate_up_buf",
7286            "ffn_hidden_buf",
7287            "ffn_out_buf",
7288            "add_tmp_buf",
7289        ] {
7290            assert!(
7291                model_rs.contains(&format!("{buf_name}: Buffer")),
7292                "MetalModel should have pre-allocated {buf_name} field"
7293            );
7294        }
7295    }
7296
7297    #[test]
7298    fn generated_dispatch_helpers_take_compute_encoder_ref() {
7299        let config = minimal_config();
7300        let model_rs = generate_model_rs(&config).unwrap();
7301
7302        for method in &[
7303            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7304            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7305            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7306            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7307            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7308            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7309            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7310            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7311            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7312            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7313            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7314            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7315            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7316        ] {
7317            assert!(
7318                model_rs.contains(method),
7319                "model.rs should contain dispatch helper: {method}"
7320            );
7321        }
7322    }
7323
7324    #[test]
7325    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7326        let config = minimal_config();
7327        let model_rs = generate_model_rs(&config).unwrap();
7328
7329        // Find dispatch helpers section and check none create their own encoders
7330        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7331        let helpers_code = &model_rs[helpers_start..];
7332
7333        // None of the dispatch_ helpers should call new_command_buffer
7334        assert!(
7335            !helpers_code.contains("self.queue.new_command_buffer()"),
7336            "dispatch helpers should not create their own command buffers"
7337        );
7338
7339        // None should create their own compute encoders
7340        assert!(
7341            !helpers_code.contains("new_compute_command_encoder()"),
7342            "dispatch helpers should not create their own compute encoders"
7343        );
7344
7345        // None should call end_encoding
7346        assert!(
7347            !helpers_code.contains("end_encoding()"),
7348            "dispatch helpers should not call end_encoding"
7349        );
7350
7351        // None should call commit or wait
7352        assert!(
7353            !helpers_code.contains(".commit()"),
7354            "dispatch helpers should not commit command buffers"
7355        );
7356        assert!(
7357            !helpers_code.contains("wait_until_completed"),
7358            "dispatch helpers should not wait on command buffers"
7359        );
7360    }
7361
7362    #[test]
7363    fn generated_forward_batches_compute_encoders() {
7364        let config = minimal_config();
7365        let model_rs = generate_model_rs(&config).unwrap();
7366
7367        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
7368        let forward_start = model_rs
7369            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7370            .unwrap();
7371        let forward_body = &model_rs[forward_start..];
7372        let forward_end = forward_body
7373            .find("\n    pub fn forward_profile")
7374            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7375            .or_else(|| forward_body.find("\n    fn dispatch_"))
7376            .unwrap_or(forward_body.len());
7377        let forward_code = &forward_body[..forward_end];
7378
7379        // Forward should not allocate new buffers
7380        assert!(
7381            !forward_code.contains("device.new_buffer"),
7382            "forward() should not allocate new buffers per call"
7383        );
7384
7385        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
7386        // Copy operations use compute copy kernels instead of blit encoders.
7387        let compute_encoder_count = forward_code
7388            .matches("new_compute_command_encoder()")
7389            .count();
7390        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7391
7392        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
7393        assert_eq!(
7394            compute_encoder_count, 1,
7395            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7396        );
7397        assert_eq!(
7398            blit_encoder_count, 0,
7399            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7400        );
7401    }
7402
7403    #[test]
7404    fn generated_forward_uses_add_inplace() {
7405        let config = minimal_config();
7406        let model_rs = generate_model_rs(&config).unwrap();
7407
7408        // Should use in-place add (no blit copy-back needed)
7409        assert!(
7410            model_rs.contains("dispatch_add_inplace"),
7411            "forward() should use dispatch_add_inplace for residual connections"
7412        );
7413        assert!(
7414            model_rs.contains("add_inplace_pipeline"),
7415            "MetalModel should have add_inplace_pipeline"
7416        );
7417    }
7418
7419    fn minimal_q8_config() -> ModelConfig {
7420        ModelConfig {
7421            architecture: Architecture::Llama,
7422            hidden_size: 64,
7423            intermediate_size: 128,
7424            num_layers: 2,
7425            num_attention_heads: 4,
7426            num_kv_heads: 4,
7427            head_dim: 16,
7428            vocab_size: 256,
7429            max_seq_len: 512,
7430            rms_norm_eps: 1e-5,
7431            rope_theta: 10000.0,
7432            dtype: DType::Q8_0,
7433            sliding_window_size: None,
7434            qkv_bias: false,
7435        }
7436    }
7437
7438    #[test]
7439    fn generated_shaders_contain_q8_kernel() {
7440        let shaders = generate_metal_shaders(&minimal_config());
7441
7442        assert!(
7443            shaders.contains("kernel void matmul_vec_q8"),
7444            "shaders should contain matmul_vec_q8 kernel"
7445        );
7446        assert!(
7447            shaders.contains("device const uchar* matrix"),
7448            "matmul_vec_q8 should accept raw Q8_0 bytes"
7449        );
7450        assert!(
7451            shaders.contains("packed_short4"),
7452            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7453        );
7454        assert!(
7455            shaders.contains("as_type<char2>"),
7456            "matmul_vec_q8 should bitcast short lanes to char2"
7457        );
7458        assert!(
7459            shaders.contains("device const half*"),
7460            "matmul_vec_q8 should read f16 scale via half pointer"
7461        );
7462    }
7463
7464    #[test]
7465    fn generated_model_uses_fused_qkv_projections() {
7466        let config = minimal_config();
7467        let model_rs = generate_model_rs(&config).unwrap();
7468
7469        // Should have fused QKV weight in layer buffers
7470        assert!(
7471            model_rs.contains("qkv_weight: Buffer"),
7472            "LayerBuffers should have fused qkv_weight field"
7473        );
7474        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
7475        assert!(
7476            !model_rs.contains("    q_weight: Buffer"),
7477            "LayerBuffers should not have separate q_weight field"
7478        );
7479        assert!(
7480            !model_rs.contains("    k_weight: Buffer"),
7481            "LayerBuffers should not have separate k_weight field"
7482        );
7483        assert!(
7484            !model_rs.contains("    v_weight: Buffer"),
7485            "LayerBuffers should not have separate v_weight field"
7486        );
7487
7488        // Should have fused gate_up_weight
7489        assert!(
7490            model_rs.contains("gate_up_weight: Buffer"),
7491            "LayerBuffers should have fused gate_up_weight field"
7492        );
7493        // Should NOT have separate gate/up weight fields
7494        assert!(
7495            !model_rs.contains("    gate_weight: Buffer"),
7496            "LayerBuffers should not have separate gate_weight field"
7497        );
7498        assert!(
7499            !model_rs.contains("    up_weight: Buffer"),
7500            "LayerBuffers should not have separate up_weight field"
7501        );
7502
7503        // Should have fused working buffers
7504        assert!(
7505            model_rs.contains("qkv_buf: Buffer"),
7506            "MetalModel should have fused qkv_buf"
7507        );
7508        assert!(
7509            model_rs.contains("gate_up_buf: Buffer"),
7510            "MetalModel should have fused gate_up_buf"
7511        );
7512
7513        // Forward pass should use fused dispatch
7514        assert!(
7515            model_rs.contains("dispatch_silu_mul_fused"),
7516            "forward pass should use dispatch_silu_mul_fused"
7517        );
7518        assert!(
7519            model_rs.contains("dispatch_rope_offset"),
7520            "forward pass should use dispatch_rope_offset for fused QKV"
7521        );
7522        assert!(
7523            model_rs.contains("dispatch_attention_offset"),
7524            "forward pass should use dispatch_attention_offset for fused QKV"
7525        );
7526    }
7527
7528    #[test]
7529    fn q8_model_has_matmul_q8_pipeline() {
7530        let config = minimal_q8_config();
7531        let model_rs = generate_model_rs(&config).unwrap();
7532
7533        assert!(
7534            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7535            "MetalModel should have matmul_q8_pipeline field"
7536        );
7537        assert!(
7538            model_rs.contains("matmul_q8_pipeline,"),
7539            "MetalModel Self should include matmul_q8_pipeline"
7540        );
7541    }
7542
7543    #[test]
7544    fn q8_model_uses_dispatch_matmul_q8() {
7545        let config = minimal_q8_config();
7546        let model_rs = generate_model_rs(&config).unwrap();
7547
7548        assert!(
7549            model_rs.contains("dispatch_matmul_q8"),
7550            "Q8_0 model should use dispatch_matmul_q8 for projections"
7551        );
7552        assert!(
7553            model_rs.contains("fn dispatch_matmul_q8"),
7554            "model.rs should define dispatch_matmul_q8 method"
7555        );
7556    }
7557
7558    #[test]
7559    fn q8_model_loads_raw_bytes_not_dequantized() {
7560        let config = minimal_q8_config();
7561        let model_rs = generate_model_rs(&config).unwrap();
7562
7563        // Should NOT contain dequantization code
7564        assert!(
7565            !model_rs.contains("f16_to_f32"),
7566            "Q8_0 model should not dequantize weights to f32"
7567        );
7568        assert!(
7569            !model_rs.contains("f32_data"),
7570            "Q8_0 model should not create f32 weight data"
7571        );
7572
7573        // Should load raw Q8_0 bytes directly
7574        assert!(
7575            model_rs.contains("total_raw as u64"),
7576            "Q8_0 model should load raw bytes into Metal buffer"
7577        );
7578    }
7579
7580    #[test]
7581    fn q8_model_norms_stay_f32() {
7582        let config = minimal_q8_config();
7583        let model_rs = generate_model_rs(&config).unwrap();
7584
7585        // Norm weights should still use f32 buffers
7586        assert!(
7587            model_rs.contains("let attn_norm = next_f32_buffer"),
7588            "attn_norm should use f32 buffer even for Q8_0 models"
7589        );
7590        assert!(
7591            model_rs.contains("let ffn_norm = next_f32_buffer"),
7592            "ffn_norm should use f32 buffer even for Q8_0 models"
7593        );
7594        assert!(
7595            model_rs.contains("let norm_buf = next_f32_buffer"),
7596            "final norm should use f32 buffer even for Q8_0 models"
7597        );
7598    }
7599
7600    #[test]
7601    fn q8_model_uses_fused_weight_loading() {
7602        let config = minimal_q8_config();
7603        let model_rs = generate_model_rs(&config).unwrap();
7604
7605        // Should use fused Q8 buffer loading for QKV
7606        assert!(
7607            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7608            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7609        );
7610        // Should use fused Q8 buffer loading for gate+up
7611        assert!(
7612            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7613            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7614        );
7615        // Should still use regular q8 buffer for individual weights
7616        assert!(
7617            model_rs.contains("let o_weight = next_q8_buffer"),
7618            "Q8_0 model should use next_q8_buffer for O weight"
7619        );
7620        assert!(
7621            model_rs.contains("let down_weight = next_q8_buffer"),
7622            "Q8_0 model should use next_q8_buffer for down weight"
7623        );
7624    }
7625
7626    #[test]
7627    fn f32_model_does_not_use_q8_dispatch() {
7628        let config = minimal_config();
7629        let model_rs = generate_model_rs(&config).unwrap();
7630
7631        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7632        let forward_start = model_rs
7633            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7634            .unwrap();
7635        let forward_body = &model_rs[forward_start..];
7636        let forward_end = forward_body
7637            .find("\n    fn dispatch_")
7638            .unwrap_or(forward_body.len());
7639        let forward_code = &forward_body[..forward_end];
7640
7641        assert!(
7642            !forward_code.contains("dispatch_matmul_q8"),
7643            "f32 model forward should not use dispatch_matmul_q8"
7644        );
7645    }
7646
7647    #[test]
7648    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7649        let config = minimal_q8_config();
7650        let model_rs = generate_model_rs(&config).unwrap();
7651
7652        assert!(
7653            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7654            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7655        );
7656    }
7657
7658    #[test]
7659    fn generated_model_has_double_buffered_prefill() {
7660        let config = minimal_config();
7661        let model_rs = generate_model_rs(&config).unwrap();
7662
7663        // MetalModel should have prev_cmd field for double-buffered prefill
7664        assert!(
7665            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7666            "MetalModel should have prev_cmd field for double-buffered prefill"
7667        );
7668
7669        // Should have forward_prefill method
7670        assert!(
7671            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7672            "MetalModel should have forward_prefill method"
7673        );
7674
7675        // forward() should drain prev_cmd at the start
7676        assert!(
7677            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7678            "forward() should drain prev_cmd from previous prefill"
7679        );
7680    }
7681
7682    #[test]
7683    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7684        let config = minimal_config();
7685        let main_rs = generate_main_rs("test-model", &config).unwrap();
7686
7687        assert!(
7688            main_rs.contains("forward_prefill"),
7689            "main.rs should use forward_prefill for intermediate prompt tokens"
7690        );
7691        assert!(
7692            main_rs.contains("double-buffered"),
7693            "main.rs should document double-buffered prefill"
7694        );
7695    }
7696
7697    #[test]
7698    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7699        let shaders = generate_metal_shaders(&minimal_config());
7700
7701        assert!(
7702            shaders.contains("packed_short4"),
7703            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7704        );
7705        assert!(
7706            shaders.contains("d0[0]"),
7707            "matmul_vec_q8 should index the wide pointer for row 0"
7708        );
7709        assert!(
7710            shaders.contains("as_type<char2>"),
7711            "matmul_vec_q8 should bitcast short lanes to char2"
7712        );
7713        assert!(
7714            shaders.contains("dot("),
7715            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7716        );
7717    }
7718
7719    // ── Q4_0 tests ──────────────────────────────────────────────────────
7720
7721    fn minimal_q4_config() -> ModelConfig {
7722        ModelConfig {
7723            architecture: Architecture::Llama,
7724            hidden_size: 64,
7725            intermediate_size: 128,
7726            num_layers: 2,
7727            num_attention_heads: 4,
7728            num_kv_heads: 4,
7729            head_dim: 16,
7730            vocab_size: 256,
7731            max_seq_len: 512,
7732            rms_norm_eps: 1e-5,
7733            rope_theta: 10000.0,
7734            dtype: DType::Q4_0,
7735            sliding_window_size: None,
7736            qkv_bias: false,
7737        }
7738    }
7739
7740    #[test]
7741    fn generated_shaders_contain_q4_kernel() {
7742        let shaders = generate_metal_shaders(&minimal_config());
7743
7744        assert!(
7745            shaders.contains("kernel void matmul_vec_q4"),
7746            "shaders should contain matmul_vec_q4 kernel"
7747        );
7748        assert!(
7749            shaders.contains("Q4_ROWS_PER_TG"),
7750            "shaders should define Q4_ROWS_PER_TG constant"
7751        );
7752        assert!(
7753            shaders.contains("Q4_ROWS_PER_SG"),
7754            "shaders should define Q4_ROWS_PER_SG constant"
7755        );
7756    }
7757
7758    #[test]
7759    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7760        let shaders = generate_metal_shaders(&minimal_config());
7761
7762        // Q4_0 kernel should use uchar4 for packed byte loads
7763        assert!(
7764            shaders.contains("uchar4"),
7765            "matmul_vec_q4 should use uchar4 for packed byte loads"
7766        );
7767        // Should unpack nibbles with &0xF and >>4
7768        assert!(
7769            shaders.contains("&0xF"),
7770            "matmul_vec_q4 should extract low nibble with &0xF"
7771        );
7772        assert!(
7773            shaders.contains(">>4"),
7774            "matmul_vec_q4 should extract high nibble with >>4"
7775        );
7776        // Should subtract 8 to convert unsigned to signed
7777        assert!(
7778            shaders.contains("-8)"),
7779            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7780        );
7781        // Should use 18-byte block size
7782        assert!(
7783            shaders.contains("blk * 18"),
7784            "matmul_vec_q4 should use 18-byte block stride"
7785        );
7786    }
7787
7788    #[test]
7789    fn q4_model_has_matmul_q4_pipeline() {
7790        let config = minimal_q4_config();
7791        let model_rs = generate_model_rs(&config).unwrap();
7792
7793        assert!(
7794            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7795            "MetalModel should have matmul_q4_pipeline field"
7796        );
7797        assert!(
7798            model_rs.contains("matmul_q4_pipeline,"),
7799            "MetalModel Self should include matmul_q4_pipeline"
7800        );
7801    }
7802
7803    #[test]
7804    fn q4_model_uses_dispatch_matmul_q4() {
7805        let config = minimal_q4_config();
7806        let model_rs = generate_model_rs(&config).unwrap();
7807
7808        assert!(
7809            model_rs.contains("dispatch_matmul_q4"),
7810            "Q4_0 model should use dispatch_matmul_q4 for projections"
7811        );
7812        assert!(
7813            model_rs.contains("fn dispatch_matmul_q4"),
7814            "model.rs should define dispatch_matmul_q4 method"
7815        );
7816    }
7817
7818    #[test]
7819    fn q4_model_loads_raw_bytes_not_dequantized() {
7820        let config = minimal_q4_config();
7821        let model_rs = generate_model_rs(&config).unwrap();
7822
7823        // Should NOT contain dequantization code
7824        assert!(
7825            !model_rs.contains("f16_to_f32"),
7826            "Q4_0 model should not dequantize weights to f32"
7827        );
7828
7829        // Should load raw Q4_0 bytes directly
7830        assert!(
7831            model_rs.contains("total_raw as u64"),
7832            "Q4_0 model should load raw bytes into Metal buffer"
7833        );
7834    }
7835
7836    #[test]
7837    fn q4_model_norms_stay_f32() {
7838        let config = minimal_q4_config();
7839        let model_rs = generate_model_rs(&config).unwrap();
7840
7841        assert!(
7842            model_rs.contains("let attn_norm = next_f32_buffer"),
7843            "attn_norm should use f32 buffer even for Q4_0 models"
7844        );
7845        assert!(
7846            model_rs.contains("let ffn_norm = next_f32_buffer"),
7847            "ffn_norm should use f32 buffer even for Q4_0 models"
7848        );
7849        assert!(
7850            model_rs.contains("let norm_buf = next_f32_buffer"),
7851            "final norm should use f32 buffer even for Q4_0 models"
7852        );
7853    }
7854
7855    #[test]
7856    fn q4_model_uses_fused_weight_loading() {
7857        let config = minimal_q4_config();
7858        let model_rs = generate_model_rs(&config).unwrap();
7859
7860        assert!(
7861            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7862            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7863        );
7864        assert!(
7865            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7866            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7867        );
7868        assert!(
7869            model_rs.contains("let o_weight = next_q4_buffer"),
7870            "Q4_0 model should use next_q4_buffer for O weight"
7871        );
7872        assert!(
7873            model_rs.contains("let down_weight = next_q4_buffer"),
7874            "Q4_0 model should use next_q4_buffer for down weight"
7875        );
7876    }
7877
7878    #[test]
7879    fn attention_flash_batch_kernel_exists() {
7880        // The flash kernel is still wired into the library (pipeline, kernel
7881        // source).  Dispatch is currently routed to the legacy path pending a
7882        // fix for a numerical issue discovered after the prompt-chunking fix.
7883        let config = minimal_config();
7884        let model_rs = generate_model_rs(&config).unwrap();
7885        let shaders = generate_metal_shaders(&config);
7886
7887        assert!(
7888            shaders.contains("kernel void attention_flash_batch"),
7889            "shaders.metal must still contain the attention_flash_batch kernel"
7890        );
7891        assert!(
7892            shaders.contains("FLASH_K_TILE"),
7893            "flash kernel must tile K/V with a FLASH_K_TILE constant"
7894        );
7895        assert!(
7896            model_rs.contains("attention_flash_batch_pipeline"),
7897            "MetalModel must register the flash attention pipeline"
7898        );
7899    }
7900
7901    #[test]
7902    fn attention_mma_flash_batch_kernel_wired() {
7903        // MMA-accelerated flash attention (issue #212).  Opt-in via
7904        // FORGE_MMA_ATTN=1 until default-enabled after broader testing.
7905        let config = minimal_config();
7906        let model_rs = generate_model_rs(&config).unwrap();
7907        let shaders = generate_metal_shaders(&config);
7908
7909        assert!(
7910            shaders.contains("kernel void attention_mma_flash_batch"),
7911            "shaders.metal must contain the MMA flash kernel"
7912        );
7913        assert!(
7914            shaders.contains("FLASH_MMA_Q_BLOCK"),
7915            "MMA flash kernel must define Q_BLOCK tiling constant"
7916        );
7917        assert!(
7918            shaders.contains("simdgroup_multiply_accumulate"),
7919            "MMA flash kernel must use hardware MMA"
7920        );
7921        assert!(
7922            model_rs.contains("attention_mma_flash_batch_pipeline"),
7923            "MetalModel must register the MMA flash pipeline"
7924        );
7925        assert!(
7926            model_rs.contains("FORGE_MMA_ATTN"),
7927            "dispatch_attention_batch must gate MMA flash on env var"
7928        );
7929    }
7930
7931    #[test]
7932    fn forward_prefill_batch_chunks_by_max_batch_size() {
7933        // Regression: prior to v0.6.4 forward_prefill_batch truncated prompts
7934        // longer than MAX_BATCH_SIZE (512) tokens via `.min(MAX_BATCH_SIZE)`,
7935        // silently dropping the middle of long prompts.  Must now loop over
7936        // MAX_BATCH_SIZE-sized chunks and carry KV-cache state across them.
7937        let config = minimal_config();
7938        let model_rs = generate_model_rs(&config).unwrap();
7939        assert!(
7940            model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
7941            "forward_prefill_batch must chunk long prompts"
7942        );
7943        assert!(
7944            !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
7945            "the old truncation path must be gone"
7946        );
7947    }
7948
7949    #[test]
7950    fn qwen2_qkv_bias_wired_through_metal_codegen() {
7951        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
7952        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
7953        // loader, pipeline, and dispatch call in the expected places.
7954        let config = ModelConfig {
7955            architecture: Architecture::Qwen2,
7956            qkv_bias: true,
7957            ..minimal_config()
7958        };
7959        let model_rs = generate_model_rs(&config).unwrap();
7960
7961        assert!(
7962            model_rs.contains("qkv_bias: Buffer"),
7963            "Qwen2 LayerBuffers must declare qkv_bias field"
7964        );
7965        assert!(
7966            model_rs.contains("let qkv_bias = next_f32_buffer"),
7967            "Qwen2 layer init must load the bias from the weight blob"
7968        );
7969        assert!(
7970            model_rs.contains("add_bias_batch_pipeline"),
7971            "Qwen2 model struct must include the add_bias_batch_pipeline"
7972        );
7973        assert!(
7974            model_rs.contains("fn dispatch_add_bias_batch"),
7975            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
7976        );
7977        assert!(
7978            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
7979            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
7980        );
7981        assert!(
7982            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
7983            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
7984        );
7985
7986        // The add_bias_batch MSL kernel must be in the shader source.
7987        let shaders = generate_metal_shaders(&config);
7988        assert!(
7989            shaders.contains("kernel void add_bias_batch"),
7990            "shaders.metal must contain the add_bias_batch kernel"
7991        );
7992    }
7993
7994    #[test]
7995    fn llama_does_not_emit_qkv_bias_machinery() {
7996        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
7997        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
7998        let config = minimal_config();
7999        assert!(!config.qkv_bias);
8000        let model_rs = generate_model_rs(&config).unwrap();
8001        assert!(
8002            !model_rs.contains("qkv_bias: Buffer"),
8003            "Llama must not have qkv_bias field"
8004        );
8005        assert!(
8006            !model_rs.contains("add_bias_batch_pipeline"),
8007            "Llama must not pull in add_bias_batch_pipeline"
8008        );
8009        assert!(
8010            !model_rs.contains("dispatch_add_bias_batch"),
8011            "Llama must not call dispatch_add_bias_batch"
8012        );
8013    }
8014
8015    #[test]
8016    fn q4_dispatch_helper_takes_compute_encoder_ref() {
8017        let config = minimal_q4_config();
8018        let model_rs = generate_model_rs(&config).unwrap();
8019
8020        assert!(
8021            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
8022            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
8023        );
8024    }
8025
8026    #[test]
8027    fn f32_model_does_not_use_q4_dispatch() {
8028        let config = minimal_config();
8029        let model_rs = generate_model_rs(&config).unwrap();
8030
8031        let forward_start = model_rs
8032            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8033            .unwrap();
8034        let forward_body = &model_rs[forward_start..];
8035        let forward_end = forward_body
8036            .find("\n    fn dispatch_")
8037            .unwrap_or(forward_body.len());
8038        let forward_code = &forward_body[..forward_end];
8039
8040        assert!(
8041            !forward_code.contains("dispatch_matmul_q4"),
8042            "f32 model forward should not use dispatch_matmul_q4"
8043        );
8044    }
8045
8046    #[test]
8047    fn q4_model_lm_head_uses_q4_buffer() {
8048        let config = minimal_q4_config();
8049        let model_rs = generate_model_rs(&config).unwrap();
8050
8051        assert!(
8052            model_rs.contains("let lm_head_buf = next_q4_buffer"),
8053            "Q4_0 model should use next_q4_buffer for lm_head"
8054        );
8055    }
8056
8057    #[test]
8058    fn vec_tile_size_matches_model_dimensions() {
8059        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
8060        let small = minimal_config();
8061        let shaders_small = generate_metal_shaders(&small);
8062        assert!(
8063            shaders_small.contains("vec_tile[128]"),
8064            "vec_tile should be sized to max(hidden, intermediate) = 128"
8065        );
8066
8067        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
8068        let mut large = minimal_config();
8069        large.hidden_size = 2048;
8070        large.intermediate_size = 8192;
8071        let shaders_large = generate_metal_shaders(&large);
8072        assert!(
8073            shaders_large.contains("vec_tile[8192]"),
8074            "vec_tile should be 8192 for models with intermediate=8192"
8075        );
8076        assert!(
8077            !shaders_large.contains("vec_tile[4096]"),
8078            "vec_tile should NOT be hardcoded to 4096"
8079        );
8080    }
8081
8082    #[test]
8083    fn generated_cargo_toml_has_server_deps() {
8084        let toml = generate_cargo_toml("my-model");
8085        assert!(
8086            toml.contains("tiny_http"),
8087            "Cargo.toml should depend on tiny_http"
8088        );
8089        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
8090        assert!(
8091            toml.contains("serde_json"),
8092            "Cargo.toml should depend on serde_json"
8093        );
8094    }
8095
8096    #[test]
8097    fn generated_main_rs_has_serve_mode() {
8098        let config = minimal_config();
8099        let main_rs = generate_main_rs("test-model", &config).unwrap();
8100
8101        assert!(
8102            main_rs.contains("--serve"),
8103            "main.rs should parse --serve flag"
8104        );
8105        assert!(
8106            main_rs.contains("--port"),
8107            "main.rs should parse --port flag"
8108        );
8109        assert!(
8110            main_rs.contains("fn serve("),
8111            "main.rs should define serve function"
8112        );
8113        assert!(
8114            main_rs.contains("tiny_http::Server::http"),
8115            "main.rs should create tiny_http server"
8116        );
8117    }
8118
8119    #[test]
8120    fn generated_main_rs_has_chat_completions_endpoint() {
8121        let config = minimal_config();
8122        let main_rs = generate_main_rs("test-model", &config).unwrap();
8123
8124        assert!(
8125            main_rs.contains("/v1/chat/completions"),
8126            "main.rs should handle /v1/chat/completions endpoint"
8127        );
8128        assert!(
8129            main_rs.contains("/v1/models"),
8130            "main.rs should handle /v1/models endpoint"
8131        );
8132        assert!(
8133            main_rs.contains("/health"),
8134            "main.rs should handle /health endpoint"
8135        );
8136    }
8137
8138    #[test]
8139    fn generated_main_rs_has_sse_streaming() {
8140        let config = minimal_config();
8141        let main_rs = generate_main_rs("test-model", &config).unwrap();
8142
8143        assert!(
8144            main_rs.contains("text/event-stream"),
8145            "main.rs should set SSE content type for streaming"
8146        );
8147        assert!(
8148            main_rs.contains("chat.completion.chunk"),
8149            "main.rs should emit SSE chunks"
8150        );
8151        assert!(
8152            main_rs.contains("[DONE]"),
8153            "main.rs should emit [DONE] sentinel"
8154        );
8155    }
8156
8157    #[test]
8158    fn generated_main_rs_has_chat_message_formatting() {
8159        let config = minimal_config();
8160        let main_rs = generate_main_rs("test-model", &config).unwrap();
8161
8162        assert!(
8163            main_rs.contains("fn format_chat_messages"),
8164            "main.rs should define format_chat_messages function"
8165        );
8166        assert!(
8167            main_rs.contains("<|im_start|>"),
8168            "main.rs should use ChatML format"
8169        );
8170        assert!(
8171            main_rs.contains("<|im_end|>"),
8172            "main.rs should use ChatML format"
8173        );
8174    }
8175
8176    #[test]
8177    fn generated_main_rs_has_request_types() {
8178        let config = minimal_config();
8179        let main_rs = generate_main_rs("test-model", &config).unwrap();
8180
8181        assert!(
8182            main_rs.contains("struct ChatRequest"),
8183            "main.rs should define ChatRequest struct"
8184        );
8185        assert!(
8186            main_rs.contains("struct ChatMessage"),
8187            "main.rs should define ChatMessage struct"
8188        );
8189        assert!(
8190            main_rs.contains("Deserialize"),
8191            "main.rs should derive Deserialize for request types"
8192        );
8193    }
8194
8195    #[test]
8196    fn generated_model_has_reset_method() {
8197        let config = minimal_config();
8198        let model_rs = generate_model_rs(&config).unwrap();
8199
8200        assert!(
8201            model_rs.contains("pub fn reset(&mut self)"),
8202            "model.rs should have a reset() method for multi-request serving"
8203        );
8204        assert!(
8205            model_rs.contains("self.pos = 0"),
8206            "reset() should reset position to 0"
8207        );
8208    }
8209
8210    #[test]
8211    fn generated_main_rs_cli_mode_still_works() {
8212        let config = minimal_config();
8213        let main_rs = generate_main_rs("test-model", &config).unwrap();
8214
8215        // CLI mode should still be functional
8216        assert!(
8217            main_rs.contains("fn cli_mode("),
8218            "main.rs should define cli_mode function"
8219        );
8220        assert!(
8221            main_rs.contains("model.forward"),
8222            "main.rs should call model.forward"
8223        );
8224        assert!(
8225            main_rs.contains("model.forward_prefill"),
8226            "main.rs should call model.forward_prefill"
8227        );
8228    }
8229
8230    // ── Batched prefill tests ──────────────────────────────────────────
8231
8232    #[test]
8233    fn generated_shaders_contain_batch_kernels() {
8234        let shaders = generate_metal_shaders(&minimal_config());
8235
8236        assert!(
8237            shaders.contains("kernel void matmul_vec_batch"),
8238            "shaders should contain matmul_vec_batch kernel"
8239        );
8240        assert!(
8241            shaders.contains("kernel void matmul_vec_q8_batch"),
8242            "shaders should contain matmul_vec_q8_batch kernel"
8243        );
8244        assert!(
8245            shaders.contains("kernel void matmul_q8_gemm_batch"),
8246            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
8247        );
8248        assert!(
8249            shaders.contains("kernel void matmul_vec_q4_batch"),
8250            "shaders should contain matmul_vec_q4_batch kernel"
8251        );
8252        assert!(
8253            shaders.contains("kernel void rms_norm_batch"),
8254            "shaders should contain rms_norm_batch kernel"
8255        );
8256        assert!(
8257            shaders.contains("kernel void silu_mul_fused_batch"),
8258            "shaders should contain silu_mul_fused_batch kernel"
8259        );
8260        assert!(
8261            shaders.contains("kernel void add_inplace_batch"),
8262            "shaders should contain add_inplace_batch kernel"
8263        );
8264        assert!(
8265            shaders.contains("kernel void copy_embedding_batch"),
8266            "shaders should contain copy_embedding_batch kernel"
8267        );
8268    }
8269
8270    #[test]
8271    fn generated_model_has_batch_pipelines() {
8272        let config = minimal_config();
8273        let model_rs = generate_model_rs(&config).unwrap();
8274
8275        for pipeline in &[
8276            "matmul_batch_pipeline",
8277            "matmul_q8_batch_pipeline",
8278            "matmul_q8_gemm_batch_pipeline",
8279            "matmul_q4_batch_pipeline",
8280            "rms_norm_batch_pipeline",
8281            "rope_batch_pipeline",
8282            "silu_mul_fused_batch_pipeline",
8283            "add_inplace_batch_pipeline",
8284            "copy_embedding_batch_pipeline",
8285            "attention_batch_pipeline",
8286            "copy_kv_batch_pipeline",
8287        ] {
8288            assert!(
8289                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
8290                "MetalModel should have {pipeline} field"
8291            );
8292        }
8293    }
8294
8295    #[test]
8296    fn generated_model_has_batch_buffers() {
8297        let config = minimal_config();
8298        let model_rs = generate_model_rs(&config).unwrap();
8299
8300        for buf in &[
8301            "batch_hidden_buf",
8302            "batch_residual_buf",
8303            "batch_qkv_buf",
8304            "batch_attn_out_buf",
8305            "batch_attn_proj_buf",
8306            "batch_gate_up_buf",
8307            "batch_ffn_hidden_buf",
8308            "batch_ffn_out_buf",
8309            "batch_tokens_buf",
8310            "batch_positions_buf",
8311        ] {
8312            assert!(
8313                model_rs.contains(&format!("{buf}: Buffer")),
8314                "MetalModel should have {buf} field"
8315            );
8316        }
8317    }
8318
8319    #[test]
8320    fn generated_model_has_forward_prefill_batch() {
8321        let config = minimal_config();
8322        let model_rs = generate_model_rs(&config).unwrap();
8323
8324        assert!(
8325            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8326            "MetalModel should have forward_prefill_batch method"
8327        );
8328
8329        // forward_prefill should delegate to forward_prefill_batch
8330        assert!(
8331            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8332            "forward_prefill should delegate to forward_prefill_batch"
8333        );
8334    }
8335
8336    #[test]
8337    fn generated_model_has_max_batch_size_constant() {
8338        let config = minimal_config();
8339        let model_rs = generate_model_rs(&config).unwrap();
8340
8341        assert!(
8342            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8343            "model.rs should define MAX_BATCH_SIZE constant"
8344        );
8345    }
8346
8347    #[test]
8348    fn forward_prefill_batch_uses_batch_dispatch() {
8349        let config = minimal_config();
8350        let model_rs = generate_model_rs(&config).unwrap();
8351
8352        let batch_start = model_rs
8353            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8354            .unwrap();
8355        let batch_body = &model_rs[batch_start..];
8356        let batch_end = batch_body
8357            .find("\n    pub fn reset")
8358            .unwrap_or(batch_body.len());
8359        let batch_code = &batch_body[..batch_end];
8360
8361        // Should use batched dispatch methods
8362        assert!(
8363            batch_code.contains("dispatch_rms_norm_batch"),
8364            "forward_prefill_batch should use dispatch_rms_norm_batch"
8365        );
8366        assert!(
8367            batch_code.contains("dispatch_copy_embedding_batch"),
8368            "forward_prefill_batch should use dispatch_copy_embedding_batch"
8369        );
8370        assert!(
8371            batch_code.contains("dispatch_silu_mul_fused_batch"),
8372            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8373        );
8374        // Should use batched causal attention dispatch
8375        assert!(
8376            batch_code.contains("dispatch_attention_batch"),
8377            "forward_prefill_batch should use dispatch_attention_batch"
8378        );
8379        // Should use fused KV cache copy (both K and V in one dispatch)
8380        assert!(
8381            batch_code.contains("dispatch_copy_kv_both_batch"),
8382            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8383        );
8384        // Should use fused RoPE Q+K dispatch
8385        assert!(
8386            batch_code.contains("dispatch_rope_qk_batch"),
8387            "forward_prefill_batch should use dispatch_rope_qk_batch"
8388        );
8389    }
8390
8391    #[test]
8392    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8393        let config = minimal_q8_config();
8394        let model_rs = generate_model_rs(&config).unwrap();
8395
8396        let batch_start = model_rs
8397            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8398            .unwrap();
8399        let batch_body = &model_rs[batch_start..];
8400        let batch_end = batch_body
8401            .find("\n    pub fn reset")
8402            .unwrap_or(batch_body.len());
8403        let batch_code = &batch_body[..batch_end];
8404
8405        assert!(
8406            batch_code.contains("dispatch_matmul_q8_batch"),
8407            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8408        );
8409    }
8410
8411    #[test]
8412    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8413        let config = minimal_q4_config();
8414        let model_rs = generate_model_rs(&config).unwrap();
8415
8416        let batch_start = model_rs
8417            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8418            .unwrap();
8419        let batch_body = &model_rs[batch_start..];
8420        let batch_end = batch_body
8421            .find("\n    pub fn reset")
8422            .unwrap_or(batch_body.len());
8423        let batch_code = &batch_body[..batch_end];
8424
8425        assert!(
8426            batch_code.contains("dispatch_matmul_q4_batch"),
8427            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8428        );
8429    }
8430
8431    #[test]
8432    fn generated_main_rs_uses_batched_prefill() {
8433        let config = minimal_config();
8434        let main_rs = generate_main_rs("test-model", &config).unwrap();
8435
8436        assert!(
8437            main_rs.contains("forward_prefill_batch"),
8438            "main.rs should use forward_prefill_batch for prompt tokens"
8439        );
8440    }
8441
8442    #[test]
8443    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8444        let config = minimal_config();
8445        let model_rs = generate_model_rs(&config).unwrap();
8446
8447        let batch_start = model_rs
8448            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8449            .unwrap();
8450        let batch_body = &model_rs[batch_start..];
8451        let batch_end = batch_body
8452            .find("\n    pub fn reset")
8453            .unwrap_or(batch_body.len());
8454        let batch_code = &batch_body[..batch_end];
8455
8456        assert!(
8457            batch_code.contains("dispatch_matmul_batch"),
8458            "f32 forward_prefill_batch should use dispatch_matmul_batch"
8459        );
8460        // Should NOT use Q8 or Q4 batch dispatch
8461        assert!(
8462            !batch_code.contains("dispatch_matmul_q8_batch"),
8463            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8464        );
8465        assert!(
8466            !batch_code.contains("dispatch_matmul_q4_batch"),
8467            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8468        );
8469    }
8470
8471    #[test]
8472    fn forward_uses_cpu_embedding_lookup() {
8473        let config = minimal_config();
8474        let model_rs = generate_model_rs(&config).unwrap();
8475
8476        // Find just the forward() body (not forward_profile)
8477        let forward_start = model_rs
8478            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8479            .unwrap();
8480        let forward_body = &model_rs[forward_start..];
8481        let forward_end = forward_body
8482            .find("\n    pub fn forward_profile")
8483            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8484            .unwrap_or(forward_body.len());
8485        let forward_code = &forward_body[..forward_end];
8486
8487        // forward() should use CPU memcpy for embedding lookup (unified memory)
8488        assert!(
8489            forward_code.contains("embed_buf.contents()"),
8490            "forward() should access embed_buf via CPU unified memory for embedding lookup"
8491        );
8492        assert!(
8493            forward_code.contains("copy_nonoverlapping"),
8494            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8495        );
8496        // forward() should NOT use GPU dispatch for embedding
8497        assert!(
8498            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8499            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8500        );
8501    }
8502
8503    #[test]
8504    fn forward_profile_method_exists() {
8505        let config = minimal_config();
8506        let model_rs = generate_model_rs(&config).unwrap();
8507
8508        assert!(
8509            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8510            "MetalModel should have forward_profile() method"
8511        );
8512        // Profile method should print timing information
8513        assert!(
8514            model_rs.contains("[profile]"),
8515            "forward_profile() should print timing with [profile] prefix"
8516        );
8517        assert!(
8518            model_rs.contains("d_embed"),
8519            "forward_profile() should measure embedding time"
8520        );
8521        assert!(
8522            model_rs.contains("d_layers"),
8523            "forward_profile() should measure layer time"
8524        );
8525        assert!(
8526            model_rs.contains("d_logits"),
8527            "forward_profile() should measure logits time"
8528        );
8529    }
8530
8531    #[test]
8532    fn generated_cli_has_profile_flag() {
8533        let config = minimal_config();
8534        let main_rs = generate_main_rs("test-model", &config).unwrap();
8535
8536        assert!(
8537            main_rs.contains("--profile"),
8538            "CLI should support --profile flag"
8539        );
8540        assert!(
8541            main_rs.contains("forward_profile"),
8542            "CLI should call forward_profile when --profile is set"
8543        );
8544    }
8545
8546    #[test]
8547    fn generated_cli_has_thermal_yield() {
8548        let config = minimal_config();
8549        let main_rs = generate_main_rs("test-model", &config).unwrap();
8550
8551        assert!(
8552            main_rs.contains("yield_now()"),
8553            "CLI generation loop should include thread::yield_now() for thermal management"
8554        );
8555    }
8556
8557    // ── Real-world validation tests ──────────────────────────────────────
8558
8559    #[test]
8560    fn generated_forward_handles_single_token_prompt() {
8561        // With a single token (the first prompt token), forward() should work
8562        // at pos=0 where seq_len=1. The attention kernel must handle the case
8563        // where there is only one KV entry (no prefill context).
8564        let config = minimal_config();
8565        let model_rs = generate_model_rs(&config).unwrap();
8566
8567        // The forward function should accept any u32 token_id (no minimum pos guard)
8568        let forward_start = model_rs
8569            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8570            .expect("forward() must exist");
8571        let forward_body = &model_rs[forward_start..forward_start + 400];
8572
8573        // Should NOT require pos > 0 or seq_len > 1
8574        assert!(
8575            !forward_body.contains("assert!(self.pos > 0"),
8576            "forward() must accept pos=0 (first token with no prefill)"
8577        );
8578
8579        // The attention kernel should handle seq_len=1 via the pos field
8580        assert!(
8581            model_rs.contains("self.pos"),
8582            "forward() should use self.pos to track sequence position"
8583        );
8584    }
8585
8586    #[test]
8587    fn generated_reset_clears_kv_cache_position() {
8588        // After reset(), the model should be in a clean state. The pos field
8589        // must be 0 so new generation starts from scratch.
8590        let config = minimal_config();
8591        let model_rs = generate_model_rs(&config).unwrap();
8592
8593        let reset_start = model_rs
8594            .find("pub fn reset(&mut self)")
8595            .expect("reset() must exist");
8596        let reset_body = &model_rs[reset_start..reset_start + 200];
8597
8598        // Reset must zero the position counter
8599        assert!(
8600            reset_body.contains("self.pos = 0"),
8601            "reset() must set self.pos = 0"
8602        );
8603
8604        // Verify reset clears prev_cmd (double-buffering state)
8605        assert!(
8606            reset_body.contains("self.prev_cmd = None"),
8607            "reset() should clear prev_cmd for clean command buffer state"
8608        );
8609    }
8610
8611    #[test]
8612    fn generated_serve_handles_empty_messages_gracefully() {
8613        // The serve endpoint should not crash when receiving an empty messages array.
8614        // The format_chat_messages function should handle this gracefully.
8615        let config = minimal_config();
8616        let main_rs = generate_main_rs("test-model", &config).unwrap();
8617
8618        // The format_chat_messages function should exist and handle empty input
8619        let format_fn_start = main_rs
8620            .find("fn format_chat_messages")
8621            .expect("format_chat_messages must exist");
8622        let format_fn_body =
8623            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8624
8625        // It should iterate over messages (an empty slice produces an empty loop)
8626        assert!(
8627            format_fn_body.contains("for msg in messages"),
8628            "format_chat_messages should iterate over the messages slice"
8629        );
8630        // It should always append the assistant prompt suffix
8631        assert!(
8632            format_fn_body.contains("<|im_start|>assistant"),
8633            "format_chat_messages should always append assistant prompt header"
8634        );
8635
8636        // The serve function should call model.reset() before each request
8637        let serve_fn_start = main_rs
8638            .find("fn serve(")
8639            .expect("serve function must exist");
8640        let serve_fn_body = &main_rs[serve_fn_start..];
8641        assert!(
8642            serve_fn_body.contains("model.reset()"),
8643            "serve function should reset model between requests"
8644        );
8645    }
8646
8647    #[test]
8648    fn generated_model_forward_increments_pos() {
8649        // Each forward() call must increment self.pos so the next token
8650        // uses the correct RoPE position and KV cache offset.
8651        let config = minimal_config();
8652        let model_rs = generate_model_rs(&config).unwrap();
8653
8654        let forward_start = model_rs
8655            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8656            .unwrap();
8657        let forward_body = &model_rs[forward_start..];
8658        let forward_end = forward_body
8659            .find("\n    pub fn forward_profile")
8660            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8661            .or_else(|| forward_body.find("\n    fn dispatch_"))
8662            .unwrap_or(forward_body.len());
8663        let forward_code = &forward_body[..forward_end];
8664
8665        assert!(
8666            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8667            "forward() must increment self.pos after processing a token"
8668        );
8669    }
8670}