Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a string usable as a Cargo package and
/// binary name.
///
/// The name is lowercased, every character that is neither alphanumeric nor
/// `-` is replaced by `-`, and leading/trailing `-` are stripped.
///
/// Cargo rejects names that are empty or that start with a digit, so two
/// fallbacks keep the generated manifest valid:
///
/// - if nothing usable remains (e.g. `""` or `"!!!"`), `"model"` is returned;
/// - if the first remaining character is not alphabetic (e.g. `"7b-chat"`),
///   the result is prefixed with `"model-"`.
///
/// Names that were already valid are returned exactly as before.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    let cleaned: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = cleaned.trim_matches('-');

    match trimmed.chars().next() {
        // Nothing but separators: an empty package name is invalid.
        None => FALLBACK.to_string(),
        // After trimming, the first char is alphanumeric; if it is not a
        // letter it is a digit-like char that Cargo refuses as a first char.
        Some(first) if !first.is_alphabetic() => format!("{}-{}", FALLBACK, trimmed),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    r#"//
122// Auto-generated by ForgeLLM Metal codegen.
123// Metal Shading Language compute kernels for transformer inference.
124//
125// Optimized with simdgroup cooperative reductions, shared memory vector
126// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
127// lane, and fast:: math intrinsics for Apple Silicon throughput.
128//
129
130#include <metal_stdlib>
131#include <metal_simdgroup_matrix>
132using namespace metal;
133
134// ── Constants ───────────────────────────────────────────────────────────
135// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
136// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
137// per shared memory load vs 4-row, improving ILP and reducing launches.
138constant constexpr uint SIMDGROUPS_PER_TG = 8;
139constant constexpr uint ROWS_PER_SIMDGROUP = 8;
140constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
141
142// ── matmul_vec ──────────────────────────────────────────────────────────
143// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
144// Uses simdgroup cooperative dot product with shared memory vector caching
145// and float4 vectorized loads. Each simdgroup processes 8 rows for better
146// shared memory reuse (8x vector reuse per load) and instruction-level
147// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
148kernel void matmul_vec(
149    device const float* matrix [[buffer(0)]],
150    device const float* vector [[buffer(1)]],
151    device float* output       [[buffer(2)]],
152    constant uint& rows        [[buffer(3)]],
153    constant uint& cols        [[buffer(4)]],
154    uint tgid [[threadgroup_position_in_grid]],
155    uint tid [[thread_index_in_threadgroup]],
156    uint simd_lane [[thread_index_in_simdgroup]],
157    uint simd_id [[simdgroup_index_in_threadgroup]])
158{
159    // Cooperatively load vector into threadgroup shared memory
160    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
161    for (uint i = tid; i < cols; i += 256) {
162        vec_tile[i] = vector[i];
163    }
164    threadgroup_barrier(mem_flags::mem_threadgroup);
165
166    // Each simdgroup handles 8 consecutive rows
167    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
168    if (row_base >= rows) return;
169
170    uint base0 = row_base * cols;
171    uint base1 = (row_base + 1) * cols;
172    uint base2 = (row_base + 2) * cols;
173    uint base3 = (row_base + 3) * cols;
174    uint base4 = (row_base + 4) * cols;
175    uint base5 = (row_base + 5) * cols;
176    uint base6 = (row_base + 6) * cols;
177    uint base7 = (row_base + 7) * cols;
178
179    // float4 vectorized accumulation across 8 rows
180    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
181    float4 sum4_0 = float4(0.0f);
182    float4 sum4_1 = float4(0.0f);
183    float4 sum4_2 = float4(0.0f);
184    float4 sum4_3 = float4(0.0f);
185    float4 sum4_4 = float4(0.0f);
186    float4 sum4_5 = float4(0.0f);
187    float4 sum4_6 = float4(0.0f);
188    float4 sum4_7 = float4(0.0f);
189
190    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
191        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
192        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
193        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
194        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
195        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
196        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
197        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
198        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
199        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
200    }
201
202    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
203    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
204    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
205    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
206    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
207    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
208    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
209    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
210
211    // Handle remaining elements (cols not divisible by 128)
212    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
213        float vv = vec_tile[j];
214        sum0 += matrix[base0 + j] * vv;
215        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
216        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
217        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
218        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
219        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
220        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
221        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
222    }
223
224    // Simdgroup hardware warp-level reduction
225    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
226    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
227    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
228    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
229
230    // Only first lane writes the results
231    if (simd_lane == 0) {
232        if (row_base     < rows) output[row_base]     = sum0;
233        if (row_base + 1 < rows) output[row_base + 1] = sum1;
234        if (row_base + 2 < rows) output[row_base + 2] = sum2;
235        if (row_base + 3 < rows) output[row_base + 3] = sum3;
236        if (row_base + 4 < rows) output[row_base + 4] = sum4;
237        if (row_base + 5 < rows) output[row_base + 5] = sum5;
238        if (row_base + 6 < rows) output[row_base + 6] = sum6;
239        if (row_base + 7 < rows) output[row_base + 7] = sum7;
240    }
241}
242
243// ── rms_norm ────────────────────────────────────────────────────────────
244// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
245// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
246// via shared memory for minimal synchronization overhead.
247kernel void rms_norm(
248    device const float* input   [[buffer(0)]],
249    device const float* weight  [[buffer(1)]],
250    device float* output        [[buffer(2)]],
251    constant uint& n            [[buffer(3)]],
252    constant float& eps         [[buffer(4)]],
253    uint tid [[thread_index_in_threadgroup]])
254{
255    // Each thread accumulates partial sum-of-squares
256    float sum_sq = 0.0f;
257    for (uint i = tid; i < n; i += 256) {
258        float v = input[i];
259        sum_sq += v * v;
260    }
261
262    // Simdgroup-level reduction (hardware warp sum)
263    sum_sq = simd_sum(sum_sq);
264
265    // Cross-simdgroup reduction via shared memory
266    threadgroup float shared[8];
267    uint simd_id = tid / 32;
268    uint simd_lane = tid % 32;
269    if (simd_lane == 0) {
270        shared[simd_id] = sum_sq;
271    }
272    threadgroup_barrier(mem_flags::mem_threadgroup);
273
274    // First thread computes final inverse RMS
275    if (tid == 0) {
276        float total = 0.0f;
277        for (uint i = 0; i < 8; i++) {
278            total += shared[i];
279        }
280        shared[0] = fast::rsqrt(total / float(n) + eps);
281    }
282    threadgroup_barrier(mem_flags::mem_threadgroup);
283
284    float inv_rms = shared[0];
285
286    // Normalize
287    for (uint i = tid; i < n; i += 256) {
288        output[i] = input[i] * inv_rms * weight[i];
289    }
290}
291
292// ── rope ────────────────────────────────────────────────────────────────
293// Rotary Position Embedding applied in-place.
294// Each thread handles one (head, pair) combination.
295kernel void rope(
296    device float* data        [[buffer(0)]],
297    constant uint& num_heads  [[buffer(1)]],
298    constant uint& head_dim   [[buffer(2)]],
299    constant uint& pos        [[buffer(3)]],
300    constant float& theta     [[buffer(4)]],
301    uint id [[thread_position_in_grid]])
302{
303    uint half_dim = head_dim / 2;
304    uint total_pairs = num_heads * half_dim;
305    if (id >= total_pairs) return;
306
307    uint h = id / half_dim;
308    uint i = id % half_dim;
309    uint off = h * head_dim;
310
311    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
312    float angle = float(pos) * freq;
313    float c = cos(angle);
314    float s = sin(angle);
315
316    float x0 = data[off + 2 * i];
317    float x1 = data[off + 2 * i + 1];
318    data[off + 2 * i]     = x0 * c - x1 * s;
319    data[off + 2 * i + 1] = x0 * s + x1 * c;
320}
321
322// ── softmax ─────────────────────────────────────────────────────────────
323// Numerically stable softmax over a 1-D array.
324// Single-threadgroup kernel with cooperative reduction.
325kernel void softmax(
326    device float* data       [[buffer(0)]],
327    constant uint& n         [[buffer(1)]],
328    uint tid [[thread_index_in_threadgroup]],
329    uint tg_size [[threads_per_threadgroup]])
330{
331    threadgroup float shared_val[256];
332
333    // Pass 1: find max
334    float local_max = -INFINITY;
335    for (uint i = tid; i < n; i += tg_size) {
336        local_max = max(local_max, data[i]);
337    }
338    shared_val[tid] = local_max;
339    threadgroup_barrier(mem_flags::mem_threadgroup);
340
341    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
342        if (tid < stride) {
343            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
344        }
345        threadgroup_barrier(mem_flags::mem_threadgroup);
346    }
347    float max_val = shared_val[0];
348    threadgroup_barrier(mem_flags::mem_threadgroup);
349
350    // Pass 2: exp and sum
351    float local_sum = 0.0f;
352    for (uint i = tid; i < n; i += tg_size) {
353        float e = fast::exp(data[i] - max_val);
354        data[i] = e;
355        local_sum += e;
356    }
357    shared_val[tid] = local_sum;
358    threadgroup_barrier(mem_flags::mem_threadgroup);
359
360    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
361        if (tid < stride) {
362            shared_val[tid] += shared_val[tid + stride];
363        }
364        threadgroup_barrier(mem_flags::mem_threadgroup);
365    }
366    float inv_sum = 1.0f / shared_val[0];
367    threadgroup_barrier(mem_flags::mem_threadgroup);
368
369    // Pass 3: normalize
370    for (uint i = tid; i < n; i += tg_size) {
371        data[i] *= inv_sum;
372    }
373}
374
375// ── silu_mul ────────────────────────────────────────────────────────────
376// Fused SiLU activation * element-wise multiply:
377//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
378kernel void silu_mul(
379    device const float* gate [[buffer(0)]],
380    device const float* up   [[buffer(1)]],
381    device float* output     [[buffer(2)]],
382    constant uint& n         [[buffer(3)]],
383    uint id [[thread_position_in_grid]])
384{
385    if (id >= n) return;
386    float g = gate[id];
387    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
388}
389
390// ── silu_mul_fused ─────────────────────────────────────────────────────
391// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
392//   gate = gate_up[0..n], up = gate_up[n..2*n]
393//   output[i] = silu(gate_up[i]) * gate_up[n + i]
394kernel void silu_mul_fused(
395    device const float* gate_up [[buffer(0)]],
396    device float* output        [[buffer(1)]],
397    constant uint& n            [[buffer(2)]],
398    uint id [[thread_position_in_grid]])
399{
400    if (id >= n) return;
401    float g = gate_up[id];
402    float u = gate_up[n + id];
403    output[id] = (g / (1.0f + fast::exp(-g))) * u;
404}
405
406// ── elementwise_add ─────────────────────────────────────────────────────
407// Residual connection: output[i] = a[i] + b[i]
408kernel void elementwise_add(
409    device const float* a  [[buffer(0)]],
410    device const float* b  [[buffer(1)]],
411    device float* output   [[buffer(2)]],
412    constant uint& n       [[buffer(3)]],
413    uint id [[thread_position_in_grid]])
414{
415    if (id >= n) return;
416    output[id] = a[id] + b[id];
417}
418
419// ── copy_buffer ─────────────────────────────────────────────────────────
420// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
421// transitions. Used for KV cache updates and embedding lookup.
422kernel void copy_buffer(
423    device const float* src [[buffer(0)]],
424    device float* dst       [[buffer(1)]],
425    constant uint& count    [[buffer(2)]],
426    uint id [[thread_position_in_grid]])
427{
428    if (id < count) dst[id] = src[id];
429}
430
431// ── copy_offset ─────────────────────────────────────────────────────────
432// Copy with source offset (in floats). Used for embedding table lookup
433// where we need to copy a specific row from a large table.
434kernel void copy_offset(
435    device const float* src     [[buffer(0)]],
436    device float* dst           [[buffer(1)]],
437    constant uint& src_offset   [[buffer(2)]],  // in floats
438    constant uint& count        [[buffer(3)]],
439    uint id [[thread_position_in_grid]])
440{
441    if (id < count) dst[id] = src[src_offset + id];
442}
443
444// ── add_inplace ─────────────────────────────────────────────────────────
445// In-place residual connection: a[i] += b[i]
446// Avoids a separate blit copy for residual add, reducing encoder overhead.
447kernel void add_inplace(
448    device float* a        [[buffer(0)]],
449    device const float* b  [[buffer(1)]],
450    constant uint& n       [[buffer(2)]],
451    uint id [[thread_position_in_grid]])
452{
453    if (id >= n) return;
454    a[id] += b[id];
455}
456
457// ── matmul_vec_q8 ─────────────────────────────────────────────────────
458// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
459// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
460// Operates directly on quantized weights to halve memory bandwidth vs f32,
461// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
462//
463// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
464// because int8->float conversion doubles register demand.  Fully unrolled
465// inner loop with float4 vector loads from shared memory eliminates loop
466// overhead and enables better instruction scheduling.
467// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
468constant constexpr uint Q8_ROWS_PER_SG = 4;
469constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
470
471// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
472// doubles ALU work, so the same register budget applies).
473constant constexpr uint Q4_ROWS_PER_SG = 4;
474constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
475
476kernel void matmul_vec_q8(
477    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
478    device const float* vector   [[buffer(1)]],  // f32 input
479    device float* output         [[buffer(2)]],
480    constant uint& rows          [[buffer(3)]],
481    constant uint& cols          [[buffer(4)]],  // number of elements per row
482    uint tgid [[threadgroup_position_in_grid]],
483    uint tid [[thread_index_in_threadgroup]],
484    uint simd_lane [[thread_index_in_simdgroup]],
485    uint simd_id [[simdgroup_index_in_threadgroup]])
486{
487    // Load vector into shared memory
488    threadgroup float vec_tile[VEC_TILE_SIZE];
489    for (uint i = tid; i < cols; i += 256) {
490        vec_tile[i] = vector[i];
491    }
492    threadgroup_barrier(mem_flags::mem_threadgroup);
493
494    // Each simdgroup handles 4 consecutive rows (lower register pressure)
495    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
496    if (row_base >= rows) return;
497
498    // Q8_0: each block is 34 bytes for 32 elements
499    uint blocks_per_row = cols / 32;
500    uint row_bytes = blocks_per_row * 34;
501
502    // Pointers to each row's Q8_0 data
503    device const uchar* r0 = matrix + row_base * row_bytes;
504    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
505    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
506    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
507
508    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
509
510    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
511        uint bb = blk * 34;
512        uint vb = blk * 32;
513
514        // Prefetch all 4 scales
515        float sc0 = float(*(device const half*)(r0 + bb));
516        float sc1 = float(*(device const half*)(r1 + bb));
517        float sc2 = float(*(device const half*)(r2 + bb));
518        float sc3 = float(*(device const half*)(r3 + bb));
519
520        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
521        // Q8_0 block layout where the int8 data starts at offset +2 from a
522        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
523        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
524        // reduction in memory transactions. Metal's char16/packed_char16 are
525        // reserved types and packed_*int4 require >=4-byte alignment which
526        // this layout does not provide, so packed_short4 is the widest valid
527        // vectorized load.
528        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
529        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
530        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
531        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
532
533        // Load all 8 float4 vector values for this 32-element block from shared memory
534        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
535        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
536        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
537        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
538        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
539        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
540        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
541        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
542
543        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
544        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
545        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
546            short4 _s = short4(SHORT4); \
547            char2 _a = as_type<char2>(_s.x); \
548            char2 _b = as_type<char2>(_s.y); \
549            char2 _c = as_type<char2>(_s.z); \
550            char2 _d = as_type<char2>(_s.w); \
551            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
552            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
553        }
554
555        float4 f0, f1;
556        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
557
558        // Row 0: 4 short4 loads cover 32 int8 weights
559        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
560        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
561        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
562        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
563
564        // Row 1
565        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
566        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
567        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
568        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
569
570        // Row 2
571        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
572        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
573        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
574        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
575
576        // Row 3
577        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
578        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
579        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
580        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
581
582        #undef Q8_UNPACK8
583
584        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
585    }
586
587    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
588    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
589
590    if (simd_lane == 0) {
591        if (row_base     < rows) output[row_base]     = sum0;
592        if (row_base + 1 < rows) output[row_base + 1] = sum1;
593        if (row_base + 2 < rows) output[row_base + 2] = sum2;
594        if (row_base + 3 < rows) output[row_base + 3] = sum3;
595    }
596}
597
598// ── matmul_vec_q4 ─────────────────────────────────────────────────────
599// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
600// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
601// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
602// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
603//
604// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
605// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
606kernel void matmul_vec_q4(
607    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
608    device const float* vector   [[buffer(1)]],  // f32 input
609    device float* output         [[buffer(2)]],
610    constant uint& rows          [[buffer(3)]],
611    constant uint& cols          [[buffer(4)]],  // number of elements per row
612    uint tgid [[threadgroup_position_in_grid]],
613    uint tid [[thread_index_in_threadgroup]],
614    uint simd_lane [[thread_index_in_simdgroup]],
615    uint simd_id [[simdgroup_index_in_threadgroup]])
616{
617    // Load vector into shared memory
618    threadgroup float vec_tile[VEC_TILE_SIZE];
619    for (uint i = tid; i < cols; i += 256) {
620        vec_tile[i] = vector[i];
621    }
622    threadgroup_barrier(mem_flags::mem_threadgroup);
623
624    // Each simdgroup handles 4 consecutive rows
625    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
626    if (row_base >= rows) return;
627
628    // Q4_0: each block is 18 bytes for 32 elements
629    uint blocks_per_row = cols / 32;
630    uint row_bytes = blocks_per_row * 18;
631
632    // Pointers to each row's Q4_0 data
633    device const uchar* r0 = matrix + row_base * row_bytes;
634    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
635    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
636    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
637
638    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
639
640    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
641        uint bb = blk * 18;
642        uint vb = blk * 32;
643
644        // Prefetch all 4 scales
645        float sc0 = float(*(device const half*)(r0 + bb));
646        float sc1 = float(*(device const half*)(r1 + bb));
647        float sc2 = float(*(device const half*)(r2 + bb));
648        float sc3 = float(*(device const half*)(r3 + bb));
649
650        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
651        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
652        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
653        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
654        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
655
656        // Load 8 float4 vector values for 32 elements from shared memory
657        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
658        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
659        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
660        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
661        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
662        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
663        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
664        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
665        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
666
667        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
668        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
669        float bd0=0, bd1=0, bd2=0, bd3=0;
670        uchar4 b;
671
672        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
673        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
674                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
675                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
676                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
677        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
678                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
679                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
680                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
681        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
682                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
683                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
684                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
685        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
686                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
687                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
688                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
689
690        // Row 1
691        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
692                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
693                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
694                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
695        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
696                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
697                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
698                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
699        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
700                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
701                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
702                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
703        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
704                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
705                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
706                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
707
708        // Row 2
709        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
710                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
711                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
712                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
713        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
714                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
715                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
716                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
717        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
718                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
719                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
720                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
721        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
722                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
723                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
724                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
725
726        // Row 3
727        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
728                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
729                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
730                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
731        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
732                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
733                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
734                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
735        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
736                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
737                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
738                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
739        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
740                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
741                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
742                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
743
744        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
745    }
746
747    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
748    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
749
750    if (simd_lane == 0) {
751        if (row_base     < rows) output[row_base]     = sum0;
752        if (row_base + 1 < rows) output[row_base + 1] = sum1;
753        if (row_base + 2 < rows) output[row_base + 2] = sum2;
754        if (row_base + 3 < rows) output[row_base + 3] = sum3;
755    }
756}
757
// ── attention ───────────────────────────────────────────────────────────
// Single-query (decode-step) attention; one threadgroup per query head.
//
// Pipeline:
//   1. scores[s] = (q · k[s]) * rsqrt(head_dim). Each simdgroup owns a
//      strided set of positions; its 32 lanes split head_dim and are folded
//      with simd_sum.
//   2. Softmax over scores: per-simdgroup simd_max / simd_sum, then a
//      threadgroup-level fold through small shared arrays.
//   3. output[d] = sum over s of scores[s] * v[s][d], one thread per d.
//
// Dispatch contract: the loop strides (8 simdgroups, 256 threads) are
// hard-coded, so the kernel must be launched with 256 threads per
// threadgroup and one threadgroup per head.
// Grouped-query mapping is head / (num_heads / num_kv_heads); this presumably
// expects num_heads to be a non-zero multiple of num_kv_heads — confirm
// against the host-side dispatch.
//
// Context-length guard: the score scratch buffer holds 2048 floats (8 KB of
// threadgroup memory). When seq_len exceeds that, the kernel attends over the
// most recent 2048 cache rows only, instead of writing past the end of the
// scratch buffer. Results are exact for seq_len <= 2048; beyond that this is
// a sliding-window approximation, and an exact result needs a tiled kernel.
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    // Uniform across the threadgroup, so returning before the barriers is safe.
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;
    uint kv_stride = num_kv_heads * head_dim;  // floats per cache row
    uint kv_base = kv_head * head_dim;         // this head's slice within a row

    // Scratch for the attention scores. The window below guarantees that
    // every index used on this array is < 2048.
    threadgroup float scores[2048];
    uint win_len = min(seq_len, 2048u);
    uint win_start = seq_len - win_len;  // 0 whenever seq_len <= 2048

    // Step 1: attention scores Q·K^T with simdgroup reduction.
    for (uint s = simd_id; s < win_len; s += 8) {
        uint k_off = (win_start + s) * kv_stride + kv_base;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: softmax over scores (cooperative).
    // Find the maximum for numerical stability.
    float local_max = -INFINITY;
    for (uint s = tid; s < win_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exponentiate in place and accumulate the normalizer.
    float local_sum = 0.0;
    for (uint s = tid; s < win_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    for (uint s = tid; s < win_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: weighted sum of V: output = scores · V.
    // Each thread owns a strided set of head_dim dimensions and walks the
    // window four positions at a time; a scalar loop handles the remainder.
    uint win_len4 = win_len & ~3u;  // largest multiple of 4 <= win_len
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        // Column d of this head's V slice, starting at the first window row.
        device const float* v_col = v_cache + win_start * kv_stride + kv_base + d;
        for (uint s = 0; s < win_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * v_col[s * kv_stride]
                 + sc1 * v_col[(s+1) * kv_stride]
                 + sc2 * v_col[(s+2) * kv_stride]
                 + sc3 * v_col[(s+3) * kv_stride];
        }
        for (uint s = win_len4; s < win_len; s++) {
            acc += scores[s] * v_col[s * kv_stride];
        }
        output[q_off + d] = acc;
    }
}
874
875// ── Batched prefill kernels ────────────────────────────────────────────
876// These kernels process M input vectors against the same weight matrix
877// in a single dispatch, converting mat-vec into mat-mat for better GPU
878// utilization during prompt prefill.
879
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization: output[t, i] = input[t, i] * inv_rms(t) * weight[i],
// where inv_rms(t) = rsqrt(mean(input[t, :]^2) + eps).
// One threadgroup per token (grid: M threadgroups). The strides below are
// hard-coded for 256 threads per threadgroup, i.e. 8 simdgroups of 32 lanes.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // Whole threadgroup exits together, so no barrier is skipped by a subset.
    if (tgid >= num_tokens) return;

    const uint row_start = tgid * n;

    // Pass 1: each thread accumulates squares over its strided elements.
    float partial = 0.0f;
    for (uint col = tid; col < n; col += 256) {
        float x = input[row_start + col];
        partial += x * x;
    }

    // Fold the 32 lanes of each simdgroup.
    partial = simd_sum(partial);

    // Fold the 8 simdgroup partials through threadgroup memory.
    threadgroup float shared[8];
    const uint group_idx = tid / 32;
    const uint lane_idx = tid % 32;
    if (lane_idx == 0) {
        shared[group_idx] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float sum_of_squares = 0.0f;
        for (uint g = 0; g < 8; g++) {
            sum_of_squares += shared[g];
        }
        // Slot 0 is reused to broadcast the reciprocal RMS.
        shared[0] = fast::rsqrt(sum_of_squares / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float scale = shared[0];

    // Pass 2: normalize and apply the learned per-channel weight.
    for (uint col = tid; col < n; col += 256) {
        output[row_start + col] = input[row_start + col] * scale * weight[col];
    }
}
929
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place rotary position embedding over a batch of tokens, each with its
// own position. data layout: [M, num_heads * head_dim], positions: [M].
// One thread per (token, head, pair); a pair is the two adjacent elements
// (2*i, 2*i + 1) inside a head.
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread id into (token, head, pair).
    const uint token = id / pairs_per_token;
    const uint within_token = id % pairs_per_token;
    const uint head = within_token / pairs_per_head;
    const uint pair = within_token % pairs_per_head;

    // Rotation angle = position * theta^(-2*pair / head_dim).
    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle = float(positions[token]) * inv_freq;
    const float cos_a = cos(angle);
    const float sin_a = sin(angle);

    // Rotate the pair as a 2-D vector.
    const uint head_start = token * (num_heads * head_dim) + head * head_dim;
    const uint even_idx = head_start + 2 * pair;
    const uint odd_idx = head_start + 2 * pair + 1;

    const float even_val = data[even_idx];
    const float odd_val = data[odd_idx];
    data[even_idx] = even_val * cos_a - odd_val * sin_a;
    data[odd_idx]  = even_val * sin_a + odd_val * cos_a;
}
964
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-and-multiply over a batch. Each gate_up row holds the n gate
// values followed by the n up values ([M, 2*n]); the result is [M, n]:
//   output[t, i] = silu(gate[t, i]) * up[t, i],  silu(x) = x / (1 + exp(-x))
// One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Guarding first also keeps the divisions below from running when n == 0.
    if (id >= num_tokens * n) return;

    const uint token = id / n;
    const uint col = id % n;

    const uint row_start = token * 2 * n;
    const float gate = gate_up[row_start + col];
    const float up = gate_up[row_start + n + col];

    const float silu = gate / (1.0f + fast::exp(-gate));
    output[token * n + col] = silu * up;
}
984
// ── add_inplace_batch ──────────────────────────────────────────────────
// Residual add over a flattened batch: a[i] = a[i] + b[i] for i in [0, total).
// One thread per element; threads at or past `total` do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
996
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous [M, dim] buffer.
// tokens[t] selects the table row copied into output row t.
// One thread per output element. Token ids are not range-checked here.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    // id is the flat index into the [M, dim] output.
    const uint out_row = id / dim;
    const uint col = id % dim;

    const uint src_index = tokens[out_row] * dim + col;
    output[id] = embed[src_index];
}
1014
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched f32 matrix-vector multiply: outputs[t, r] = matrix[r, :] · inputs[t, :]
// for M input vectors sharing one weight matrix.
// Grid: ceil(rows/ROWS_PER_TG) * M threadgroups, flattened so that
// tgid = token * row_groups + row_group. Strides are hard-coded for 256
// threads per threadgroup; each simdgroup produces 8 consecutive rows.
// NOTE(review): the staging loop copies `cols` floats into vec_tile, so this
// presumably relies on cols <= VEC_TILE_SIZE — confirm on the host side.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Split the flat threadgroup index into (token, row group).
    const uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    const uint token = tgid / groups_per_token;
    const uint group = tgid % groups_per_token;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory so all eight
    // simdgroups read it from there instead of device memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* token_in = inputs + token * cols;
    for (uint c = tid; c < cols; c += 256) {
        vec_tile[c] = token_in[c];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // First of the 8 rows owned by this simdgroup. The barrier is already
    // behind us, so a whole simdgroup may leave here.
    const uint row_base = group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Row 0 is valid by the check above; rows 1..7 may fall off the matrix.
    const bool live1 = row_base + 1 < rows;
    const bool live2 = row_base + 2 < rows;
    const bool live3 = row_base + 3 < rows;
    const bool live4 = row_base + 4 < rows;
    const bool live5 = row_base + 5 < rows;
    const bool live6 = row_base + 6 < rows;
    const bool live7 = row_base + 7 < rows;

    // Element offsets of each row's first weight.
    const uint off0 = row_base * cols;
    const uint off1 = (row_base + 1) * cols;
    const uint off2 = (row_base + 2) * cols;
    const uint off3 = (row_base + 3) * cols;
    const uint off4 = (row_base + 4) * cols;
    const uint off5 = (row_base + 5) * cols;
    const uint off6 = (row_base + 6) * cols;
    const uint off7 = (row_base + 7) * cols;

    // Vector phase: 32 lanes * float4 = 128 columns per step, covering the
    // largest multiple of 128 that fits in cols.
    const uint vec_end = cols & ~127u;
    float4 acc0 = float4(0.0f);
    float4 acc1 = float4(0.0f);
    float4 acc2 = float4(0.0f);
    float4 acc3 = float4(0.0f);
    float4 acc4 = float4(0.0f);
    float4 acc5 = float4(0.0f);
    float4 acc6 = float4(0.0f);
    float4 acc7 = float4(0.0f);

    for (uint c = simd_lane * 4; c < vec_end; c += 128) {
        const float4 x = *reinterpret_cast<threadgroup const float4*>(vec_tile + c);
        acc0 += *reinterpret_cast<device const float4*>(matrix + off0 + c) * x;
        if (live1) acc1 += *reinterpret_cast<device const float4*>(matrix + off1 + c) * x;
        if (live2) acc2 += *reinterpret_cast<device const float4*>(matrix + off2 + c) * x;
        if (live3) acc3 += *reinterpret_cast<device const float4*>(matrix + off3 + c) * x;
        if (live4) acc4 += *reinterpret_cast<device const float4*>(matrix + off4 + c) * x;
        if (live5) acc5 += *reinterpret_cast<device const float4*>(matrix + off5 + c) * x;
        if (live6) acc6 += *reinterpret_cast<device const float4*>(matrix + off6 + c) * x;
        if (live7) acc7 += *reinterpret_cast<device const float4*>(matrix + off7 + c) * x;
    }

    // Collapse each float4 accumulator to a per-lane scalar.
    float sum0 = acc0.x + acc0.y + acc0.z + acc0.w;
    float sum1 = acc1.x + acc1.y + acc1.z + acc1.w;
    float sum2 = acc2.x + acc2.y + acc2.z + acc2.w;
    float sum3 = acc3.x + acc3.y + acc3.z + acc3.w;
    float sum4 = acc4.x + acc4.y + acc4.z + acc4.w;
    float sum5 = acc5.x + acc5.y + acc5.z + acc5.w;
    float sum6 = acc6.x + acc6.y + acc6.z + acc6.w;
    float sum7 = acc7.x + acc7.y + acc7.z + acc7.w;

    // Scalar phase: the columns left over after the last 128-wide step.
    for (uint c = vec_end + simd_lane; c < cols; c += 32) {
        const float x = vec_tile[c];
        sum0 += matrix[off0 + c] * x;
        if (live1) sum1 += matrix[off1 + c] * x;
        if (live2) sum2 += matrix[off2 + c] * x;
        if (live3) sum3 += matrix[off3 + c] * x;
        if (live4) sum4 += matrix[off4 + c] * x;
        if (live5) sum5 += matrix[off5 + c] * x;
        if (live6) sum6 += matrix[off6 + c] * x;
        if (live7) sum7 += matrix[off7 + c] * x;
    }

    // Fold the 32 lane partials of every row.
    sum0 = simd_sum(sum0);
    sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2);
    sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4);
    sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6);
    sum7 = simd_sum(sum7);

    // Lane 0 publishes the rows that exist.
    if (simd_lane == 0) {
        device float* out_row = outputs + token * rows;
        out_row[row_base] = sum0;
        if (live1) out_row[row_base + 1] = sum1;
        if (live2) out_row[row_base + 2] = sum2;
        if (live3) out_row[row_base + 3] = sum3;
        if (live4) out_row[row_base + 4] = sum4;
        if (live5) out_row[row_base + 5] = sum5;
        if (live6) out_row[row_base + 6] = sum6;
        if (live7) out_row[row_base + 7] = sum7;
    }
}
1116
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups, flattened so that
// tgid = token * row_groups + row_group. Strides are hard-coded for 256
// threads per threadgroup; each simdgroup produces 4 consecutive rows.
//
// Weight layout as read below: every row is cols/32 blocks of 34 bytes,
// each a half-precision scale followed by 32 signed 8-bit quants.
//
// Tail rows: when rows is not a multiple of 4, the last simdgroup owns rows
// that do not exist. Their pointers are aliased to row 0 of the simdgroup so
// every load stays inside the weight buffer; the sums computed for them are
// discarded by the guarded stores at the end.
//
// NOTE(review): the staging loop copies `cols` floats into vec_tile, so this
// presumably relies on cols <= VEC_TILE_SIZE — confirm on the host side.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform across the threadgroup, so no barrier is skipped by a subset.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    // Row 0 is valid by the check above; rows 1..3 may fall off the matrix.
    const bool has1 = row_base + 1 < rows;
    const bool has2 = row_base + 2 < rows;
    const bool has3 = row_base + 3 < rows;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Missing rows alias row 0 so their loads never leave the buffer.
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = has1 ? matrix + (row_base + 1) * row_bytes : r0;
    device const uchar* r2 = has2 ? matrix + (row_base + 2) * row_bytes : r0;
    device const uchar* r3 = has3 ? matrix + (row_base + 3) * row_bytes : r0;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Expand one packed_short4 (8 quant bytes) into two float4 values.
    #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
        short4 _s = short4(SHORT4); \
        char2 _a = as_type<char2>(_s.x); \
        char2 _b = as_type<char2>(_s.y); \
        char2 _c = as_type<char2>(_s.z); \
        char2 _d = as_type<char2>(_s.w); \
        (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
        (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
    }

    // Each lane takes every 32nd block of the row.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of the block within a row
        uint vb = blk * 32;  // matching element offset in the input vector

        // Per-block scale: the leading half of each 34-byte block.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 input values that pair with this block, shared by all 4 rows.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        // Apply each block's scale once, after the integer-valued dot.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    #undef Q8_UNPACK8

    // Fold the 32 lane partials of every row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 publishes the rows that exist; aliased tail rows are dropped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        output[row_base] = sum0;
        if (has1) output[row_base + 1] = sum1;
        if (has2) output[row_base + 2] = sum2;
        if (has3) output[row_base + 3] = sum3;
    }
}
1232
1233// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
1234// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers 32 rows and TOKENS_PER_TG_Q8 consecutive tokens, so
// the Q8_0 weight blocks are fetched once from device memory and reused for
// every token in the tile (1/TOKENS_PER_TG_Q8 the weight bandwidth of the
// per-token dispatch).
1239//
// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
1241// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
1242// with simd_sum.  Token vectors are read directly from device memory inside
1243// the block loop (not cached in shared memory) so intermediate_size up to
1244// 8192 fits without spilling threadgroup memory.
1245constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1246
1247kernel void matmul_q8_gemm_batch(
1248    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1249    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1250    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1251    constant uint& num_tokens    [[buffer(3)]],  // M
1252    constant uint& rows          [[buffer(4)]],
1253    constant uint& cols          [[buffer(5)]],
1254    uint2 tgid [[threadgroup_position_in_grid]],
1255    uint tid [[thread_index_in_threadgroup]],
1256    uint simd_lane [[thread_index_in_simdgroup]],
1257    uint simd_id [[simdgroup_index_in_threadgroup]])
1258{
1259    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1260    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1261    if (row_base >= rows || tok_base >= num_tokens) return;
1262
1263    // How many tokens in this tile are valid?
1264    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1265
1266    uint blocks_per_row = cols / 32;
1267    uint row_bytes = blocks_per_row * 34;
1268
1269    device const uchar* r0 = matrix + row_base * row_bytes;
1270    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1271    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1272    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1273
1274    // Accumulators: 4 tokens × 4 rows per simdgroup.
1275    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1276    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1277    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1278    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1279
1280    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1281        uint bb = blk * 34;
1282        uint vb = blk * 32;
1283
1284        // ── Load weight data ONCE per block (reused across all tokens) ──
1285        float sc0 = float(*(device const half*)(r0 + bb));
1286        float sc1 = float(*(device const half*)(r1 + bb));
1287        float sc2 = float(*(device const half*)(r2 + bb));
1288        float sc3 = float(*(device const half*)(r3 + bb));
1289
1290        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1291        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1292        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1293        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1294
1295        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1296            short4 _s = short4(SHORT4); \
1297            char2 _a = as_type<char2>(_s.x); \
1298            char2 _b = as_type<char2>(_s.y); \
1299            char2 _c = as_type<char2>(_s.z); \
1300            char2 _d = as_type<char2>(_s.w); \
1301            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1302            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1303        }
1304
1305        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1306        // registers for the duration of the block and are dotted against
1307        // every token's vector tile.
1308        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1309        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1310        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1311        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1312
1313        Q8_UNPACK8(d0[0], w0_0, w0_1);
1314        Q8_UNPACK8(d0[1], w0_2, w0_3);
1315        Q8_UNPACK8(d0[2], w0_4, w0_5);
1316        Q8_UNPACK8(d0[3], w0_6, w0_7);
1317
1318        Q8_UNPACK8(d1[0], w1_0, w1_1);
1319        Q8_UNPACK8(d1[1], w1_2, w1_3);
1320        Q8_UNPACK8(d1[2], w1_4, w1_5);
1321        Q8_UNPACK8(d1[3], w1_6, w1_7);
1322
1323        Q8_UNPACK8(d2[0], w2_0, w2_1);
1324        Q8_UNPACK8(d2[1], w2_2, w2_3);
1325        Q8_UNPACK8(d2[2], w2_4, w2_5);
1326        Q8_UNPACK8(d2[3], w2_6, w2_7);
1327
1328        Q8_UNPACK8(d3[0], w3_0, w3_1);
1329        Q8_UNPACK8(d3[1], w3_2, w3_3);
1330        Q8_UNPACK8(d3[2], w3_4, w3_5);
1331        Q8_UNPACK8(d3[3], w3_6, w3_7);
1332
1333        #undef Q8_UNPACK8
1334
1335        // ── For each token, read vector and accumulate against shared weights ──
1336        // Token 0 (always valid: tok_count >= 1).
1337        {
1338            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1339            float4 v0 = *(device const float4*)(a0);
1340            float4 v1 = *(device const float4*)(a0 + 4);
1341            float4 v2 = *(device const float4*)(a0 + 8);
1342            float4 v3 = *(device const float4*)(a0 + 12);
1343            float4 v4 = *(device const float4*)(a0 + 16);
1344            float4 v5 = *(device const float4*)(a0 + 20);
1345            float4 v6 = *(device const float4*)(a0 + 24);
1346            float4 v7 = *(device const float4*)(a0 + 28);
1347            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1348                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1349            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1350                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1351            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1352                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1353            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1354                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1355            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1356        }
1357        // Token 1
1358        if (tok_count > 1) {
1359            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1360            float4 v0 = *(device const float4*)(a1);
1361            float4 v1 = *(device const float4*)(a1 + 4);
1362            float4 v2 = *(device const float4*)(a1 + 8);
1363            float4 v3 = *(device const float4*)(a1 + 12);
1364            float4 v4 = *(device const float4*)(a1 + 16);
1365            float4 v5 = *(device const float4*)(a1 + 20);
1366            float4 v6 = *(device const float4*)(a1 + 24);
1367            float4 v7 = *(device const float4*)(a1 + 28);
1368            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1369                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1370            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1371                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1372            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1373                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1374            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1375                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1376            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1377        }
1378        // Token 2
1379        if (tok_count > 2) {
1380            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1381            float4 v0 = *(device const float4*)(a2);
1382            float4 v1 = *(device const float4*)(a2 + 4);
1383            float4 v2 = *(device const float4*)(a2 + 8);
1384            float4 v3 = *(device const float4*)(a2 + 12);
1385            float4 v4 = *(device const float4*)(a2 + 16);
1386            float4 v5 = *(device const float4*)(a2 + 20);
1387            float4 v6 = *(device const float4*)(a2 + 24);
1388            float4 v7 = *(device const float4*)(a2 + 28);
1389            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1390                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1391            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1392                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1393            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1394                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1395            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1396                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1397            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1398        }
1399        // Token 3
1400        if (tok_count > 3) {
1401            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1402            float4 v0 = *(device const float4*)(a3);
1403            float4 v1 = *(device const float4*)(a3 + 4);
1404            float4 v2 = *(device const float4*)(a3 + 8);
1405            float4 v3 = *(device const float4*)(a3 + 12);
1406            float4 v4 = *(device const float4*)(a3 + 16);
1407            float4 v5 = *(device const float4*)(a3 + 20);
1408            float4 v6 = *(device const float4*)(a3 + 24);
1409            float4 v7 = *(device const float4*)(a3 + 28);
1410            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1411                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1412            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1413                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1414            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1415                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1416            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1417                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1418            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1419        }
1420    }
1421
1422    // simdgroup reduction
1423    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1424    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1425    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1426    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1427
1428    if (simd_lane == 0) {
1429        device float* o0 = outputs + (tok_base + 0) * rows;
1430        if (row_base     < rows) o0[row_base]     = s00;
1431        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1432        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1433        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1434
1435        if (tok_count > 1) {
1436            device float* o1 = outputs + (tok_base + 1) * rows;
1437            if (row_base     < rows) o1[row_base]     = s10;
1438            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1439            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1440            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1441        }
1442        if (tok_count > 2) {
1443            device float* o2 = outputs + (tok_base + 2) * rows;
1444            if (row_base     < rows) o2[row_base]     = s20;
1445            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1446            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1447            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1448        }
1449        if (tok_count > 3) {
1450            device float* o3 = outputs + (tok_base + 3) * rows;
1451            if (row_base     < rows) o3[row_base]     = s30;
1452            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1453            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1454            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1455        }
1456    }
1457}
1458
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Q8_0-weight GEMM built on Apple Silicon simdgroup_matrix tiles
// (simdgroup_multiply_accumulate) rather than scalar dot products.  Meant for
// prompt prefill, where many token rows are multiplied in one dispatch.
//
// Work split:
//   * One threadgroup produces a 16-token × 16-row output tile.  The staging
//     code below assumes 128 threads per threadgroup (4 simdgroups of 32).
//   * Each simdgroup owns a single 8×8 float accumulator; the four of them
//     form a 2×2 arrangement inside the 16×16 tile.
//   * K is walked one Q8_0 block (32 values) per iteration.  For every block
//     the threadgroup dequantizes 16 weight rows and stages 16 token rows in
//     threadgroup memory, which all four simdgroups then read.
//
// Q8_0 block layout as read here: a 2-byte half scale followed by 32 int8
// quants, i.e. 34 bytes per block.
//
// Preconditions (expected to be checked by the host-side dispatch helper,
// which falls back to another kernel otherwise):
//   * cols % 32 == 0   (whole Q8_0 blocks along K)
//   * rows % 16 == 0   (no partial tile along the row axis)
// num_tokens is unrestricted: token rows past the end are zero-filled when
// staged, and a partially valid 8×8 sub-tile is written via a scratch copy.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint tile_row0 = tgid.x * MMA_ROW_TILE;
    const uint tile_tok0 = tgid.y * MMA_TOK_TILE;
    // The test depends only on tgid, so the whole threadgroup leaves together
    // and nobody is left waiting at the barriers below.
    if (tile_row0 >= rows || tile_tok0 >= num_tokens) {
        return;
    }

    // Staging tiles: 16 × 32 floats each (2 KB + 2 KB).
    threadgroup float w_stage[MMA_ROW_TILE * 32];
    threadgroup float t_stage[MMA_TOK_TILE * 32];

    // Position of this simdgroup's 8×8 sub-tile inside the 16×16 output tile.
    const uint sub_tok0 = (simd_id >> 1) * 8;   // token offset
    const uint sub_row0 = (simd_id & 1u) * 8;   // weight-row offset

    simdgroup_matrix<float, 8, 8> acc = simdgroup_matrix<float, 8, 8>(0.0f);

    const uint n_blocks = cols / 32;
    const uint row_stride_bytes = n_blocks * 34;

    // Every thread stages 4 consecutive K slots of one tile row (8 threads per
    // row).  This is the same mapping as idx = tid*4 + j, row = idx/32,
    // k = idx%32, with the row/column split computed once.
    const uint stage_row = tid / 8;
    const uint stage_k0  = (tid % 8) * 4;

    for (uint blk = 0; blk < n_blocks; ++blk) {
        // Dequantize this thread's slice of the weight tile.
        device const uchar* blk_ptr =
            matrix + (tile_row0 + stage_row) * row_stride_bytes + blk * 34;
        const float scale = float(*(device const half*)blk_ptr);
        for (uint j = 0; j < 4; ++j) {
            const uint k = stage_k0 + j;
            const int q = int(*(device const int8_t*)(blk_ptr + 2 + k));
            w_stage[stage_row * 32 + k] = float(q) * scale;
        }

        // Stage this thread's slice of the token tile; rows beyond num_tokens
        // contribute zeros so they cannot disturb the accumulators.
        const uint tok = tile_tok0 + stage_row;
        for (uint j = 0; j < 4; ++j) {
            const uint k = stage_k0 + j;
            float v = 0.0f;
            if (tok < num_tokens) {
                v = inputs[tok * cols + blk * 32 + k];
            }
            t_stage[stage_row * 32 + k] = v;
        }

        // Both tiles must be complete before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four 8-wide K slices per block:
        //   a_frag[m, k] = t_stage[(sub_tok0 + m) * 32 + k_off + k]
        //   b_frag[k, r] = w_stage[(sub_row0 + r) * 32 + k_off + k]  (transposed load)
        //   acc[m, r]   += a_frag[m, k] * b_frag[k, r]
        for (uint k_off = 0; k_off < 32; k_off += 8) {
            simdgroup_matrix<float, 8, 8> a_frag, b_frag;
            simdgroup_load(a_frag, t_stage + sub_tok0 * 32 + k_off, 32, ulong2(0, 0), false);
            simdgroup_load(b_frag, w_stage + sub_row0 * 32 + k_off, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(acc, a_frag, b_frag, acc);
        }

        // Keep the next iteration from overwriting tiles that are still in use.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write back.  Output layout is outputs[tok * rows + row].
    const uint out_tok0 = tile_tok0 + sub_tok0;
    const uint out_row0 = tile_row0 + sub_row0;
    if (out_tok0 + 8 <= num_tokens) {
        // All 8 token rows of the sub-tile exist: store it directly.
        simdgroup_store(acc, outputs + out_tok0 * rows + out_row0, rows);
    } else if (out_tok0 < num_tokens) {
        // Only the first few token rows exist: spill the sub-tile into a
        // per-simdgroup scratch area and let one lane copy the valid rows.
        threadgroup float scratch[4 * 64];
        threadgroup float* my_scratch = scratch + simd_id * 64;
        simdgroup_store(acc, my_scratch, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            const uint live = num_tokens - out_tok0;  // 1..7
            for (uint m = 0; m < live; ++m) {
                device float* dst = outputs + (out_tok0 + m) * rows + out_row0;
                for (uint j = 0; j < 8; ++j) {
                    dst[j] = my_scratch[m * 8 + j];
                }
            }
        }
    }
}
1587
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// 32×32-tile sibling of matmul_q8_mma, aimed at long-context prefill.
//
// One threadgroup produces a 32-token × 32-row output tile, consuming one
// Q8_0 block (K = 32) per iteration.  The staging code assumes 256 threads
// (8 simdgroups).  The tile is a 4×4 grid of 8×8 sub-tiles; every simdgroup
// owns two of them, adjacent along the row axis:
//
//     simd_id = 2 * tok_slot + row_half        (tok_slot in 0..3, row_half in 0..1)
//     acc_a -> (tok_slot * 8, row_half * 16 + 0)
//     acc_b -> (tok_slot * 8, row_half * 16 + 8)
//
// The token fragment loaded for a K slice is used for both accumulators, so
// each simdgroup issues two MMAs per three fragment loads.  A 32×32 tile
// covers four times the outputs of a 16×16 tile, so the grid needs a quarter
// as many threadgroups.
//
// Preconditions (expected to be checked by the host-side dispatch helper,
// which falls back otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint tile_row0 = tgid.x * MMA32_ROW_TILE;
    const uint tile_tok0 = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together.
    if (tile_row0 >= rows || tile_tok0 >= num_tokens) {
        return;
    }

    // Staging tiles: 32 × 32 floats each (4 KB + 4 KB).
    threadgroup float w_stage[MMA32_ROW_TILE * 32];
    threadgroup float t_stage[MMA32_TOK_TILE * 32];

    // Sub-tile origins for this simdgroup inside the 32×32 output tile.
    const uint sub_tok0   = (simd_id >> 1) * 8;    // tok_slot * 8
    const uint sub_row0_a = (simd_id & 1u) * 16;   // row_half * 16
    const uint sub_row0_b = sub_row0_a + 8;

    simdgroup_matrix<float, 8, 8> acc_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> acc_b = simdgroup_matrix<float, 8, 8>(0.0f);

    const uint n_blocks = cols / 32;
    const uint row_stride_bytes = n_blocks * 34;

    // 32*32 values / 256 threads = 4 consecutive K slots per thread, 8 threads
    // per tile row (equivalent to idx = tid*4 + j, row = idx/32, k = idx%32).
    const uint stage_row = tid / 8;
    const uint stage_k0  = (tid % 8) * 4;

    for (uint blk = 0; blk < n_blocks; ++blk) {
        // Dequantize this thread's slice of the weight tile.
        device const uchar* blk_ptr =
            matrix + (tile_row0 + stage_row) * row_stride_bytes + blk * 34;
        const float scale = float(*(device const half*)blk_ptr);
        for (uint j = 0; j < 4; ++j) {
            const uint k = stage_k0 + j;
            const int q = int(*(device const int8_t*)(blk_ptr + 2 + k));
            w_stage[stage_row * 32 + k] = float(q) * scale;
        }

        // Stage this thread's slice of the token tile (zeros past num_tokens).
        const uint tok = tile_tok0 + stage_row;
        for (uint j = 0; j < 4; ++j) {
            const uint k = stage_k0 + j;
            float v = 0.0f;
            if (tok < num_tokens) {
                v = inputs[tok * cols + blk * 32 + k];
            }
            t_stage[stage_row * 32 + k] = v;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four 8-wide K slices; the token fragment feeds both accumulators.
        for (uint k_off = 0; k_off < 32; k_off += 8) {
            simdgroup_matrix<float, 8, 8> a_frag, b_frag_a, b_frag_b;
            simdgroup_load(a_frag,   t_stage + sub_tok0   * 32 + k_off, 32, ulong2(0, 0), false);
            simdgroup_load(b_frag_a, w_stage + sub_row0_a * 32 + k_off, 32, ulong2(0, 0), true);
            simdgroup_load(b_frag_b, w_stage + sub_row0_b * 32 + k_off, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(acc_a, a_frag, b_frag_a, acc_a);
            simdgroup_multiply_accumulate(acc_b, a_frag, b_frag_b, acc_b);
        }

        // Tiles are overwritten next iteration; wait until everyone is done.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write back.  The row axis is assumed tile-aligned (see preconditions),
    // so only the token axis can be partial.
    const uint out_tok0   = tile_tok0 + sub_tok0;
    const uint out_row0_a = tile_row0 + sub_row0_a;
    const uint out_row0_b = tile_row0 + sub_row0_b;
    if (out_tok0 + 8 <= num_tokens) {
        simdgroup_store(acc_a, outputs + out_tok0 * rows + out_row0_a, rows);
        simdgroup_store(acc_b, outputs + out_tok0 * rows + out_row0_b, rows);
    } else if (out_tok0 < num_tokens) {
        // 8 simdgroups × 2 accumulators × 64 floats of scratch.
        threadgroup float scratch[8 * 2 * 64];
        threadgroup float* spill_a = scratch + simd_id * 128;
        threadgroup float* spill_b = scratch + simd_id * 128 + 64;
        simdgroup_store(acc_a, spill_a, 8);
        simdgroup_store(acc_b, spill_b, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            const uint live = num_tokens - out_tok0;  // 1..7
            for (uint m = 0; m < live; ++m) {
                device float* dst_a = outputs + (out_tok0 + m) * rows + out_row0_a;
                device float* dst_b = outputs + (out_tok0 + m) * rows + out_row0_b;
                for (uint j = 0; j < 8; ++j) {
                    dst_a[j] = spill_a[m * 8 + j];
                    dst_b[j] = spill_b[m * 8 + j];
                }
            }
        }
    }
}
1728
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// matmul_q8_mma32 with half-precision staging tiles.
//
// Dequantized weights and token activations are written to threadgroup
// memory as `half`, so the two 32×32 staging tiles take 4 KB instead of 8 KB
// (the partial-store scratch below is still float and is not counted in that
// figure).  The stated goal is higher threadgroup occupancy per GPU core at
// moderate prefill lengths.
//
// Numerics: each Q8_0 value is int8 × half scale (the scale is read from the
// block as a 2-byte half), and the product is rounded to half when staged.
// Activations are narrowed float -> half when staged.  The MMA accumulators
// stay simdgroup_matrix<float>, so summation over K is carried out in float.
// NOTE(review): activations outside the half range would not survive the
// narrowing; presumably inputs are normalized upstream — confirm.
//
// Tile and simdgroup layout are identical to matmul_q8_mma32: 32 × 32 tile,
// 256 threads assumed, 8 simdgroups × 2 row-adjacent 8×8 accumulators.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint tile_row0 = tgid.x * MMA32_ROW_TILE;
    const uint tile_tok0 = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together.
    if (tile_row0 >= rows || tile_tok0 >= num_tokens) {
        return;
    }

    // Staging tiles: 32 × 32 halves each (2 KB + 2 KB).
    threadgroup half w_stage[MMA32_ROW_TILE * 32];
    threadgroup half t_stage[MMA32_TOK_TILE * 32];

    // Sub-tile origins for this simdgroup inside the 32×32 output tile.
    const uint sub_tok0   = (simd_id >> 1) * 8;
    const uint sub_row0_a = (simd_id & 1u) * 16;
    const uint sub_row0_b = sub_row0_a + 8;

    simdgroup_matrix<float, 8, 8> acc_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> acc_b = simdgroup_matrix<float, 8, 8>(0.0f);

    const uint n_blocks = cols / 32;
    const uint row_stride_bytes = n_blocks * 34;

    // 4 consecutive K slots per thread, 8 threads per tile row (equivalent to
    // idx = tid*4 + j, row = idx/32, k = idx%32).
    const uint stage_row = tid / 8;
    const uint stage_k0  = (tid % 8) * 4;

    for (uint blk = 0; blk < n_blocks; ++blk) {
        // Dequantize in float, then round to half for staging.
        device const uchar* blk_ptr =
            matrix + (tile_row0 + stage_row) * row_stride_bytes + blk * 34;
        const float scale = float(*(device const half*)blk_ptr);
        for (uint j = 0; j < 4; ++j) {
            const uint k = stage_k0 + j;
            const int q = int(*(device const int8_t*)(blk_ptr + 2 + k));
            w_stage[stage_row * 32 + k] = half(float(q) * scale);
        }

        // Stage token activations as half (zeros past num_tokens).
        const uint tok = tile_tok0 + stage_row;
        for (uint j = 0; j < 4; ++j) {
            const uint k = stage_k0 + j;
            half v = half(0);
            if (tok < num_tokens) {
                v = half(inputs[tok * cols + blk * 32 + k]);
            }
            t_stage[stage_row * 32 + k] = v;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Half fragments in, float accumulators out.
        for (uint k_off = 0; k_off < 32; k_off += 8) {
            simdgroup_matrix<half, 8, 8> a_frag, b_frag_a, b_frag_b;
            simdgroup_load(a_frag,   t_stage + sub_tok0   * 32 + k_off, 32, ulong2(0, 0), false);
            simdgroup_load(b_frag_a, w_stage + sub_row0_a * 32 + k_off, 32, ulong2(0, 0), true);
            simdgroup_load(b_frag_b, w_stage + sub_row0_b * 32 + k_off, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(acc_a, a_frag, b_frag_a, acc_a);
            simdgroup_multiply_accumulate(acc_b, a_frag, b_frag_b, acc_b);
        }

        // Tiles are overwritten next iteration; wait until everyone is done.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write back; only the token axis can be partial.
    const uint out_tok0   = tile_tok0 + sub_tok0;
    const uint out_row0_a = tile_row0 + sub_row0_a;
    const uint out_row0_b = tile_row0 + sub_row0_b;
    if (out_tok0 + 8 <= num_tokens) {
        simdgroup_store(acc_a, outputs + out_tok0 * rows + out_row0_a, rows);
        simdgroup_store(acc_b, outputs + out_tok0 * rows + out_row0_b, rows);
    } else if (out_tok0 < num_tokens) {
        // 8 simdgroups × 2 accumulators × 64 floats of scratch.
        threadgroup float scratch[8 * 2 * 64];
        threadgroup float* spill_a = scratch + simd_id * 128;
        threadgroup float* spill_b = scratch + simd_id * 128 + 64;
        simdgroup_store(acc_a, spill_a, 8);
        simdgroup_store(acc_b, spill_b, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            const uint live = num_tokens - out_tok0;
            for (uint m = 0; m < live; ++m) {
                device float* dst_a = outputs + (out_tok0 + m) * rows + out_row0_a;
                device float* dst_b = outputs + (out_tok0 + m) * rows + out_row0_b;
                for (uint j = 0; j < 8; ++j) {
                    dst_a[j] = spill_a[m * 8 + j];
                    dst_b[j] = spill_b[m * 8 + j];
                }
            }
        }
    }
}
1857
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// Half-staged 32×32 MMA kernel run by 4 simdgroups instead of 8.
//
// Each simdgroup covers one 16×16 quadrant of the 32×32 output tile with a
// 2×2 set of 8×8 float accumulators:
//   acc_00 (tok 0..8,  row 0..8)    acc_01 (tok 0..8,  row 8..16)
//   acc_10 (tok 8..16, row 0..8)    acc_11 (tok 8..16, row 8..16)
//
// Per 8-wide K slice a simdgroup loads two token fragments and two weight
// fragments and issues four MMAs (each token fragment paired with each weight
// fragment).  The staging code assumes 128 threads per threadgroup, half the
// thread count of the 8-simdgroup kernels; the stated motivation is that the
// concurrent-thread budget is often the tighter occupancy limit on Apple GPUs.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint tile_row0 = tgid.x * MMA32_ROW_TILE;
    const uint tile_tok0 = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together.
    if (tile_row0 >= rows || tile_tok0 >= num_tokens) {
        return;
    }

    // Staging tiles: 32 × 32 halves each (2 KB + 2 KB).
    threadgroup half w_stage[MMA32_ROW_TILE * 32];
    threadgroup half t_stage[MMA32_TOK_TILE * 32];

    // Quadrant origin for this simdgroup (2×2 grid of 16×16 quadrants).
    const uint quad_tok0 = (simd_id >> 1) * 16;
    const uint quad_row0 = (simd_id & 1u) * 16;

    simdgroup_matrix<float, 8, 8> acc_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> acc_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> acc_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> acc_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    const uint n_blocks = cols / 32;
    const uint row_stride_bytes = n_blocks * 34;

    // 32*32 values / 128 threads = 8 consecutive K slots per thread, 4 threads
    // per tile row (equivalent to idx = tid*8 + j, row = idx/32, k = idx%32).
    const uint stage_row = tid / 4;
    const uint stage_k0  = (tid % 4) * 8;

    for (uint blk = 0; blk < n_blocks; ++blk) {
        // Dequantize in float, then round to half for staging.
        device const uchar* blk_ptr =
            matrix + (tile_row0 + stage_row) * row_stride_bytes + blk * 34;
        const float scale = float(*(device const half*)blk_ptr);
        for (uint j = 0; j < 8; ++j) {
            const uint k = stage_k0 + j;
            const int q = int(*(device const int8_t*)(blk_ptr + 2 + k));
            w_stage[stage_row * 32 + k] = half(float(q) * scale);
        }

        // Stage token activations as half (zeros past num_tokens).
        const uint tok = tile_tok0 + stage_row;
        for (uint j = 0; j < 8; ++j) {
            const uint k = stage_k0 + j;
            half v = half(0);
            if (tok < num_tokens) {
                v = half(inputs[tok * cols + blk * 32 + k]);
            }
            t_stage[stage_row * 32 + k] = v;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four 8-wide K slices, four MMAs each.
        for (uint k_off = 0; k_off < 32; k_off += 8) {
            simdgroup_matrix<half, 8, 8> a_top, a_bot, b_lo, b_hi;
            simdgroup_load(a_top, t_stage + (quad_tok0 + 0) * 32 + k_off, 32, ulong2(0, 0), false);
            simdgroup_load(a_bot, t_stage + (quad_tok0 + 8) * 32 + k_off, 32, ulong2(0, 0), false);
            simdgroup_load(b_lo,  w_stage + (quad_row0 + 0) * 32 + k_off, 32, ulong2(0, 0), true);
            simdgroup_load(b_hi,  w_stage + (quad_row0 + 8) * 32 + k_off, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(acc_00, a_top, b_lo, acc_00);
            simdgroup_multiply_accumulate(acc_01, a_top, b_hi, acc_01);
            simdgroup_multiply_accumulate(acc_10, a_bot, b_lo, acc_10);
            simdgroup_multiply_accumulate(acc_11, a_bot, b_hi, acc_11);
        }

        // Tiles are overwritten next iteration; wait until everyone is done.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write back the four sub-tiles of this quadrant.
    const uint out_tok_top = tile_tok0 + quad_tok0;
    const uint out_tok_bot = out_tok_top + 8;
    const uint out_row_lo  = tile_row0 + quad_row0;
    const uint out_row_hi  = out_row_lo + 8;
    if (out_tok_bot + 8 <= num_tokens) {
        // All 16 token rows of the quadrant exist: store directly.
        simdgroup_store(acc_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(acc_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(acc_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(acc_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Some (possibly all) token rows of the quadrant are past num_tokens.
        // Spill all four accumulators to per-simdgroup scratch and let one
        // lane copy only the rows that exist.
        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
        threadgroup float* spill = scratch + simd_id * 256;
        simdgroup_store(acc_00, spill +   0, 8);
        simdgroup_store(acc_01, spill +  64, 8);
        simdgroup_store(acc_10, spill + 128, 8);
        simdgroup_store(acc_11, spill + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            for (uint m = 0; m < 8; ++m) {
                const uint tok_top = out_tok_top + m;
                if (tok_top < num_tokens) {
                    device float* dst_lo = outputs + tok_top * rows + out_row_lo;
                    device float* dst_hi = outputs + tok_top * rows + out_row_hi;
                    for (uint j = 0; j < 8; ++j) {
                        dst_lo[j] = spill[  0 + m * 8 + j];
                        dst_hi[j] = spill[ 64 + m * 8 + j];
                    }
                }
                const uint tok_bot = out_tok_bot + m;
                if (tok_bot < num_tokens) {
                    device float* dst_lo = outputs + tok_bot * rows + out_row_lo;
                    device float* dst_hi = outputs + tok_bot * rows + out_row_hi;
                    for (uint j = 0; j < 8; ++j) {
                        dst_lo[j] = spill[128 + m * 8 + j];
                        dst_hi[j] = spill[192 + m * 8 + j];
                    }
                }
            }
        }
    }
}
2012
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)).  Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Layout, as read by the code below:
//   matrix  — quantized rows; every 32 weights form a 34-byte block
//             (2-byte half scale followed by 32 int8 values).
//   inputs  — [num_tokens, cols] f32 activations.
//   outputs — [num_tokens, rows] f32 results.
// tgid.x selects the row tile and tgid.y the token tile.  The staging code
// (8 elements per thread) and the four 16x16 simdgroup quadrants assume
// 32x32 tiles driven by 128 threads (4 simdgroups) per threadgroup.
// NOTE(review): confirm tile constants and threadgroup size against the
// host-side dispatch.  `cols` is expected to be a multiple of 32.
//
// Partial tiles are handled on both axes: tokens >= num_tokens and weight
// rows >= rows are staged as zeros and skipped on store, so a trailing
// partial row tile neither reads past the weight buffer nor overwrites
// another token's outputs.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Staging tiles for one 32-wide K block: dequantized weights and tokens.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Each simdgroup owns one 16x16 (token x row) quadrant of the output tile.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // Four 8x8 accumulators cover the 16x16 quadrant.
    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Every thread stages 8 consecutive K elements of both tiles.  Because 8
    // divides 32, those elements always share one tile line, so the line
    // index, its validity and the destination pointers are loop-invariant.
    uint stage_line = (tid * 8) / 32;
    uint stage_k0   = (tid * 8) % 32;
    uint w_row = row_base + stage_line;
    uint t_tok = tok_base + stage_line;
    bool w_row_valid = w_row < rows;
    bool t_tok_valid = t_tok < num_tokens;
    threadgroup half* w_dst = w_tile + stage_line * 32 + stage_k0;
    threadgroup half* t_dst = t_tile + stage_line * 32 + stage_k0;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Dequantize this thread's slice of the weight block (zeros for rows
        // beyond the matrix so they contribute nothing and are never read).
        if (w_row_valid) {
            device const uchar* rp = matrix + w_row * row_bytes + blk * 34;
            float sc = float(*(device const half*)rp);
            device const int8_t* q = (device const int8_t*)(rp + 2 + stage_k0);
            for (uint ii = 0; ii < 8; ii++) {
                w_dst[ii] = half(float(int(q[ii])) * sc);
            }
        } else {
            for (uint ii = 0; ii < 8; ii++) {
                w_dst[ii] = half(0);
            }
        }
        // Narrow this thread's slice of the activations to half (zeros for
        // tokens beyond the batch).
        if (t_tok_valid) {
            uint in_off = t_tok * cols + blk * 32 + stage_k0;
            for (uint ii = 0; ii < 8; ii++) {
                t_dst[ii] = half(inputs[in_off + ii]);
            }
        } else {
            for (uint ii = 0; ii < 8; ii++) {
                t_dst[ii] = half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Four 8-wide K sub-steps per 32-wide block; the weight fragments are
        // loaded transposed so C += tokens * weights^T.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Tiles are overwritten on the next iteration; wait for all readers.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;

    // 256 halves (four 8x8 fragments) of scratch per simdgroup.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);
    uint lane = tid % 32;
    if (lane == 0) {
        // Lane 0 of each simdgroup copies its quadrant out, skipping tokens
        // and rows that fall outside the real output matrix.
        for (uint m = 0; m < 8; m++) {
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base +   0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base +  64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    if (out_row_lo + j < rows) dst0[j] = float(src0[j]);
                    if (out_row_hi + j < rows) dst1[j] = float(src1[j]);
                }
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    if (out_row_lo + j < rows) dst2[j] = float(src2[j]);
                    if (out_row_hi + j < rows) dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2150
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup first copies its token's input vector into threadgroup
// memory, then every simdgroup accumulates four consecutive output rows.
// A Q4_0 block is read as 18 bytes: a 2-byte half scale followed by 16 bytes
// of packed nibbles, where byte j carries element j in its low nibble and
// element j + 16 in its high nibble (both stored with a +8 bias).
//
// The copy loop strides by 256, i.e. it assumes 256 threads per threadgroup,
// and `cols` must fit in the vec_tile array and be a multiple of 32.
// NOTE(review): confirm both against the host-side dispatch.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decompose the flat threadgroup index into (token, row group).
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so whole simdgroups may leave early here.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // When `rows` is not a multiple of 4 the trailing rows of the last group
    // do not exist.  Clamp them to the last valid row so the unrolled loads
    // below never read past the weight buffer; their sums are discarded by
    // the bounds checks at the final store.  (rows >= 1 here because
    // row_base < rows.)
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Lanes of the simdgroup take interleaved 32-element blocks.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;
        uint vb = blk * 32;

        // Per-block scales of the four rows.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Packed nibbles start 2 bytes into the block, hence the packed
        // (byte-aligned) vector type.
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 hold elements 0..15 (low nibbles), v4..v7 elements 16..31
        // (high nibbles) of the staged input block.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2251
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter one projection (K or V) of a batch into its cache.
// One thread moves one float.
//   src: [M, src_stride] rows; the wanted vector starts src_offset floats
//        into each row.
//   dst: cache rows of kv_dim floats; batch row t lands at row base_pos + t.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer
    device float* dst        [[buffer(1)]],  // KV cache
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    // Threads beyond the M * kv_dim payload have nothing to copy.
    if (id >= M * kv_dim) return;

    // Flat thread index -> (batch row, element within the KV vector).
    const uint row = id / kv_dim;
    const uint col = id % kv_dim;

    const uint from = row * src_stride + src_offset + col;
    const uint to   = (base_pos + row) * kv_dim + col;
    dst[to] = src[from];
}
2274
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// The loops below stride by 8 over simdgroups and by 256 over threads, and
// the reduction arrays hold 8 entries, i.e. the kernel is written for
// threadgroups of 256 threads (8 simdgroups of 32).
// NOTE(review): confirm the host dispatches exactly that size.
//
// NOTE(review): `scores` holds 2048 entries and seq_len (base_pos +
// token_idx + 1) is not checked against it inside this kernel — confirm the
// host never dispatches this kernel with positions beyond 2048.
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    // Grouped-query mapping: consecutive groups of (num_heads / num_kv_heads)
    // query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores
    threadgroup float scores[2048];

    // Step 1: Q * K^T with simdgroup reduction
    // Each simdgroup takes every 8th position; its 32 lanes split head_dim.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            // Scale by 1/sqrt(head_dim).
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a: maximum score — per thread, per simdgroup, then across simdgroups.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate relative to the maximum and accumulate the sum.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Store the reciprocal so normalization is a multiply; a zero total
        // yields zero weights instead of dividing by zero.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // 2c: normalize in place.
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
    // This is much better than the scalar version where only 64 of 256 threads are active.
    // NOTE(review): the float4 casts below address v_cache/output_batch at
    // float offsets that are only multiples of 4 when head_dim is — confirm
    // head_dim % 4 == 0 (and 16-byte aligned buffers) for every model routed
    // through this kernel.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        // Positions are processed four at a time, then a scalar-position tail.
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2399
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Rotary position embedding applied in place to the Q and K sections of a
// batch QKV buffer with a single dispatch (one launch instead of two).
// One thread rotates one adjacent (even, odd) element pair.
// Row layout per token: Q heads first (num_q_heads * head_dim floats),
// immediately followed by the K heads (num_kv_heads * head_dim floats).
kernel void rope_qk_batch(
    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
    constant uint& M                 [[buffer(1)]],   // num tokens
    constant uint& base_pos          [[buffer(2)]],   // starting position
    constant uint& num_q_heads       [[buffer(3)]],
    constant uint& num_kv_heads      [[buffer(4)]],
    constant uint& head_dim          [[buffer(5)]],
    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
    constant float& theta            [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head  = head_dim / 2;
    const uint pairs_per_token = (num_q_heads + num_kv_heads) * pairs_per_head;

    // Flat thread index -> (token, pair within that token's Q+K span).
    const uint token = id / pairs_per_token;
    const uint pair_in_token = id % pairs_per_token;
    if (token >= M) return;

    // Pairs below q_pair_count belong to Q; the rest index into K, whose
    // section starts right after all Q heads.
    const uint q_pair_count = num_q_heads * pairs_per_head;
    const bool in_k_section = pair_in_token >= q_pair_count;
    const uint section_pair = in_k_section ? (pair_in_token - q_pair_count) : pair_in_token;
    const uint section_start = in_k_section ? (num_q_heads * head_dim) : 0u;

    const uint head_idx = section_pair / pairs_per_head;
    const uint i = section_pair % pairs_per_head;
    const uint offset = token * qkv_stride + section_start + head_idx * head_dim + i * 2;

    // Rotation angle: absolute position times theta^(-2i / head_dim).
    const uint pos = base_pos + token;
    const float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    const float angle = float(pos) * freq;
    const float c = cos(angle);
    const float s = sin(angle);

    // Rotate the pair in place; both inputs are read before either is written.
    const float even = qkv_data[offset];
    const float odd  = qkv_data[offset + 1];
    qkv_data[offset]     = even * c - odd * s;
    qkv_data[offset + 1] = even * s + odd * c;
}
2450
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Single-dispatch variant of the K and V cache copies: the first M * kv_dim
// threads move K, the next M * kv_dim threads move V, each from the strided
// batch QKV buffer into its own cache. One thread moves one float, so a
// layer needs one launch here instead of two copy_kv_batch launches.
kernel void copy_kv_both_batch(
    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
    constant uint& M           [[buffer(3)]],  // num tokens in batch
    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    // Floats per cache for this batch; the grid covers twice that (K then V).
    const uint per_cache = M * kv_dim;
    if (id >= per_cache * 2) return;

    // Second half of the index space targets V, first half targets K.
    const bool to_v = (id / per_cache) != 0;
    const uint within = id % per_cache;

    // Index inside one cache's payload -> (batch row, element).
    const uint row = within / kv_dim;
    const uint col = within % kv_dim;

    const uint section = to_v ? v_offset : k_offset;
    const float value = src[row * src_stride + section + col];
    const uint cache_index = (base_pos + row) * kv_dim + col;

    if (to_v) {
        v_dst[cache_index] = value;
    } else {
        k_dst[cache_index] = value;
    }
}
2485"#
2486    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2487}
2488
2489// ---------------------------------------------------------------------------
2490// model.rs generation
2491// ---------------------------------------------------------------------------
2492
2493fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2494    let mut code = String::with_capacity(48 * 1024);
2495    emit_model_header(&mut code, config)?;
2496    emit_metal_model_struct(&mut code, config)?;
2497    emit_layer_buffers_struct(&mut code)?;
2498    emit_metal_model_impl(&mut code, config)?;
2499    emit_helper_functions(&mut code)?;
2500    Ok(code)
2501}
2502
2503fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2504    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2505    writeln!(
2506        code,
2507        "//! Model: {} ({} layers, hidden={})",
2508        config.architecture, config.num_layers, config.hidden_size
2509    )?;
2510    writeln!(code, "//!")?;
2511    writeln!(
2512        code,
2513        "//! Uses native Metal compute pipelines via the metal crate."
2514    )?;
2515    writeln!(code)?;
2516    writeln!(code, "#![allow(dead_code)]")?;
2517    writeln!(code)?;
2518    writeln!(code, "use metal::*;")?;
2519    writeln!(code, "#[allow(unused_imports)]")?;
2520    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2521    writeln!(code, "use std::mem;")?;
2522    writeln!(code)?;
2523
2524    // Model constants
2525    writeln!(
2526        code,
2527        "// ── Model constants ──────────────────────────────────"
2528    )?;
2529    writeln!(
2530        code,
2531        "pub const HIDDEN_SIZE: usize = {};",
2532        config.hidden_size
2533    )?;
2534    writeln!(
2535        code,
2536        "pub const INTERMEDIATE_SIZE: usize = {};",
2537        config.intermediate_size
2538    )?;
2539    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2540    writeln!(
2541        code,
2542        "pub const NUM_HEADS: usize = {};",
2543        config.num_attention_heads
2544    )?;
2545    writeln!(
2546        code,
2547        "pub const NUM_KV_HEADS: usize = {};",
2548        config.num_kv_heads
2549    )?;
2550    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2551    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2552    let effective_seq_len = config.max_seq_len.min(4096);
2553    writeln!(
2554        code,
2555        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
2556        effective_seq_len, config.max_seq_len
2557    )?;
2558    writeln!(
2559        code,
2560        "pub const RMS_NORM_EPS: f32 = {:e};",
2561        config.rms_norm_eps
2562    )?;
2563    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
2564    writeln!(
2565        code,
2566        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
2567    )?;
2568    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
2569    writeln!(code)?;
2570
2571    Ok(())
2572}
2573
2574fn emit_metal_model_struct(
2575    code: &mut String,
2576    _config: &ModelConfig,
2577) -> Result<(), MetalCodegenError> {
2578    writeln!(
2579        code,
2580        "// ── MetalModel ──────────────────────────────────────────"
2581    )?;
2582    writeln!(code)?;
2583    writeln!(
2584        code,
2585        "/// Metal-accelerated transformer model for Apple Silicon."
2586    )?;
2587    writeln!(code, "///")?;
2588    writeln!(
2589        code,
2590        "/// Uses unified memory for zero-copy weight access and native Metal"
2591    )?;
2592    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
2593    writeln!(code, "pub struct MetalModel {{")?;
2594    writeln!(code, "    device: Device,")?;
2595    writeln!(code, "    queue: CommandQueue,")?;
2596    writeln!(code)?;
2597    writeln!(code, "    // ── Compute pipelines ──")?;
2598    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
2599    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
2600    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
2601    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
2602    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
2603    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
2604    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
2605    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
2606    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
2607    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
2608    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
2609    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
2610    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
2611    writeln!(code)?;
2612    writeln!(code, "    // ── Batched prefill pipelines ──")?;
2613    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
2614    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
2615    writeln!(
2616        code,
2617        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
2618    )?;
2619    writeln!(
2620        code,
2621        "    matmul_q8_mma_pipeline: ComputePipelineState,"
2622    )?;
2623    writeln!(
2624        code,
2625        "    matmul_q8_mma32_pipeline: ComputePipelineState,"
2626    )?;
2627    writeln!(
2628        code,
2629        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
2630    )?;
2631    writeln!(
2632        code,
2633        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
2634    )?;
2635    writeln!(
2636        code,
2637        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
2638    )?;
2639    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
2640    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
2641    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
2642    writeln!(
2643        code,
2644        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
2645    )?;
2646    writeln!(
2647        code,
2648        "    add_inplace_batch_pipeline: ComputePipelineState,"
2649    )?;
2650    writeln!(
2651        code,
2652        "    copy_embedding_batch_pipeline: ComputePipelineState,"
2653    )?;
2654    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
2655    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
2656    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
2657    writeln!(
2658        code,
2659        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
2660    )?;
2661    writeln!(code)?;
2662    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
2663    writeln!(
2664        code,
2665        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
2666    )?;
2667    writeln!(code, "    embed_buf: Buffer,")?;
2668    writeln!(code)?;
2669    writeln!(code, "    /// Per-layer weight buffers")?;
2670    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
2671    writeln!(code)?;
2672    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
2673    writeln!(code, "    norm_buf: Buffer,")?;
2674    writeln!(code)?;
2675    writeln!(
2676        code,
2677        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
2678    )?;
2679    writeln!(code, "    lm_head_buf: Buffer,")?;
2680    writeln!(code)?;
2681    writeln!(
2682        code,
2683        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
2684    )?;
2685    writeln!(code, "    hidden_buf: Buffer,")?;
2686    writeln!(code, "    residual_buf: Buffer,")?;
2687    writeln!(code, "    normed_buf: Buffer,")?;
2688    writeln!(
2689        code,
2690        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
2691    )?;
2692    writeln!(code, "    qkv_buf: Buffer,")?;
2693    writeln!(code, "    attn_out_buf: Buffer,")?;
2694    writeln!(code, "    attn_proj_buf: Buffer,")?;
2695    writeln!(
2696        code,
2697        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
2698    )?;
2699    writeln!(code, "    gate_up_buf: Buffer,")?;
2700    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
2701    writeln!(code, "    ffn_out_buf: Buffer,")?;
2702    writeln!(code, "    add_tmp_buf: Buffer,")?;
2703    writeln!(code, "    logits_buf: Buffer,")?;
2704    writeln!(code)?;
2705    writeln!(code, "    // ── Batched prefill working buffers ──")?;
2706    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
2707    writeln!(code, "    batch_hidden_buf: Buffer,")?;
2708    writeln!(
2709        code,
2710        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
2711    )?;
2712    writeln!(code, "    batch_residual_buf: Buffer,")?;
2713    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2714    writeln!(code, "    batch_qkv_buf: Buffer,")?;
2715    writeln!(
2716        code,
2717        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2718    )?;
2719    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
2720    writeln!(
2721        code,
2722        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2723    )?;
2724    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
2725    writeln!(
2726        code,
2727        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2728    )?;
2729    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
2730    writeln!(
2731        code,
2732        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2733    )?;
2734    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
2735    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2736    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
2737    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
2738    writeln!(code, "    batch_tokens_buf: Buffer,")?;
2739    writeln!(code, "    /// Positions buffer for batch RoPE")?;
2740    writeln!(code, "    batch_positions_buf: Buffer,")?;
2741    writeln!(code)?;
2742    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
2743    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
2744    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
2745    writeln!(code)?;
2746    writeln!(code, "    // ── Inference state ──")?;
2747    writeln!(code, "    pos: usize,")?;
2748    writeln!(code)?;
2749    writeln!(
2750        code,
2751        "    /// Previous command buffer for double-buffered prefill."
2752    )?;
2753    writeln!(
2754        code,
2755        "    /// While the GPU executes token N, the CPU can encode token N+1."
2756    )?;
2757    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
2758    writeln!(code, "}}")?;
2759    writeln!(code)?;
2760
2761    Ok(())
2762}
2763
2764fn emit_layer_buffers_struct(code: &mut String) -> Result<(), MetalCodegenError> {
2765    writeln!(
2766        code,
2767        "/// Per-layer weight buffers for attention and FFN projections."
2768    )?;
2769    writeln!(code, "struct LayerBuffers {{")?;
2770    writeln!(code, "    attn_norm: Buffer,")?;
2771    writeln!(
2772        code,
2773        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2774    )?;
2775    writeln!(code, "    qkv_weight: Buffer,")?;
2776    writeln!(code, "    o_weight: Buffer,")?;
2777    writeln!(code, "    ffn_norm: Buffer,")?;
2778    writeln!(
2779        code,
2780        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2781    )?;
2782    writeln!(code, "    gate_up_weight: Buffer,")?;
2783    writeln!(code, "    down_weight: Buffer,")?;
2784    writeln!(code, "}}")?;
2785    writeln!(code)?;
2786
2787    Ok(())
2788}
2789
2790fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2791    let hidden = config.hidden_size;
2792    let intermediate = config.intermediate_size;
2793    let _num_layers = config.num_layers;
2794    let num_heads = config.num_attention_heads;
2795    let num_kv_heads = config.num_kv_heads;
2796    let head_dim = config.head_dim;
2797    let vocab = config.vocab_size;
2798    let effective_seq_len = config.max_seq_len.min(4096);
2799    let is_q8 = config.dtype == DType::Q8_0;
2800    let is_q4 = config.dtype == DType::Q4_0;
2801    let kv_dim = num_kv_heads * head_dim;
2802
2803    writeln!(code, "impl MetalModel {{")?;
2804
2805    // ── new() ──
2806    writeln!(
2807        code,
2808        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
2809    )?;
2810    writeln!(code, "    ///")?;
2811    writeln!(
2812        code,
2813        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
2814    )?;
2815    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
2816    writeln!(
2817        code,
2818        "        let device = Device::system_default().expect(\"no Metal device found\");"
2819    )?;
2820    writeln!(code, "        let queue = device.new_command_queue();")?;
2821    writeln!(code)?;
2822
2823    // Compile shaders
2824    writeln!(
2825        code,
2826        "        // Compile Metal shaders from embedded source"
2827    )?;
2828    writeln!(
2829        code,
2830        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
2831    )?;
2832    writeln!(
2833        code,
2834        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
2835    )?;
2836    writeln!(
2837        code,
2838        "            .expect(\"failed to compile Metal shaders\");"
2839    )?;
2840    writeln!(code)?;
2841
2842    // Create compute pipelines
2843    writeln!(code, "        // Create compute pipelines")?;
2844    for (var, fn_name) in [
2845        ("matmul_pipeline", "matmul_vec"),
2846        ("matmul_q8_pipeline", "matmul_vec_q8"),
2847        ("matmul_q4_pipeline", "matmul_vec_q4"),
2848        ("rms_norm_pipeline", "rms_norm"),
2849        ("rope_pipeline", "rope"),
2850        ("softmax_pipeline", "softmax"),
2851        ("silu_mul_pipeline", "silu_mul"),
2852        ("silu_mul_fused_pipeline", "silu_mul_fused"),
2853        ("add_pipeline", "elementwise_add"),
2854        ("attention_pipeline", "attention"),
2855        ("add_inplace_pipeline", "add_inplace"),
2856        ("copy_pipeline", "copy_buffer"),
2857        ("copy_offset_pipeline", "copy_offset"),
2858        ("matmul_batch_pipeline", "matmul_vec_batch"),
2859        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
2860        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
2861        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
2862        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
2863        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
2864        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
2865        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
2866        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
2867        ("rms_norm_batch_pipeline", "rms_norm_batch"),
2868        ("rope_batch_pipeline", "rope_batch"),
2869        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
2870        ("add_inplace_batch_pipeline", "add_inplace_batch"),
2871        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
2872        ("attention_batch_pipeline", "attention_batch"),
2873        ("copy_kv_batch_pipeline", "copy_kv_batch"),
2874        ("rope_qk_batch_pipeline", "rope_qk_batch"),
2875        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
2876    ] {
2877        writeln!(
2878            code,
2879            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
2880        )?;
2881    }
2882    writeln!(code)?;
2883
2884    // Weight loading
2885    writeln!(
2886        code,
2887        "        // Load weights into Metal shared-memory buffers"
2888    )?;
2889    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
2890    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
2891    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
2892    writeln!(code)?;
2893    writeln!(
2894        code,
2895        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
2896    )?;
2897    writeln!(code)?;
2898    writeln!(
2899        code,
2900        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
2901    )?;
2902    writeln!(
2903        code,
2904        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
2905    )?;
2906    writeln!(code, "            let byte_len = n * f32_size;")?;
2907    writeln!(code, "            let cur = cursor.get();")?;
2908    writeln!(
2909        code,
2910        "            let data = &weights[cur..cur + byte_len];"
2911    )?;
2912    writeln!(code, "            cursor.set(cur + byte_len);")?;
2913    writeln!(code, "            device.new_buffer_with_data(")?;
2914    writeln!(code, "                data.as_ptr() as *const _,")?;
2915    writeln!(code, "                byte_len as u64,")?;
2916    writeln!(
2917        code,
2918        "                MTLResourceOptions::StorageModeShared,"
2919    )?;
2920    writeln!(code, "            )")?;
2921    writeln!(code, "        }};")?;
2922    writeln!(code)?;
2923
2924    if is_q8 {
2925        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
2926        // We load them directly into Metal buffers without dequantizing,
2927        // and use the matmul_vec_q8 shader that operates on quantized data.
2928        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
2929        writeln!(
2930            code,
2931            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
2932        )?;
2933        writeln!(
2934            code,
2935            "        // as raw bytes into a Metal buffer (no dequantization)."
2936        )?;
2937        writeln!(
2938            code,
2939            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
2940        )?;
2941        writeln!(
2942            code,
2943            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
2944        )?;
2945        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2946        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
2947        writeln!(code, "            let total_raw = rows * row_bytes;")?;
2948        writeln!(code, "            let cur = cursor.get();")?;
2949        writeln!(
2950            code,
2951            "            let data = &weights[cur..cur + total_raw];"
2952        )?;
2953        writeln!(code, "            cursor.set(cur + total_raw);")?;
2954        writeln!(code, "            device.new_buffer_with_data(")?;
2955        writeln!(code, "                data.as_ptr() as *const _,")?;
2956        writeln!(code, "                total_raw as u64,")?;
2957        writeln!(
2958            code,
2959            "                MTLResourceOptions::StorageModeShared,"
2960        )?;
2961        writeln!(code, "            )")?;
2962        writeln!(code, "        }};")?;
2963        writeln!(code)?;
2964        writeln!(
2965            code,
2966            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
2967        )?;
2968        writeln!(
2969            code,
2970            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
2971        )?;
2972        writeln!(
2973            code,
2974            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
2975        )?;
2976        writeln!(
2977            code,
2978            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
2979        )?;
2980        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2981        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
2982        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
2983        writeln!(code, "            let cur = cursor.get();")?;
2984        writeln!(
2985            code,
2986            "            let data = &weights[cur..cur + total_raw];"
2987        )?;
2988        writeln!(code, "            cursor.set(cur + total_raw);")?;
2989        writeln!(code, "            device.new_buffer_with_data(")?;
2990        writeln!(code, "                data.as_ptr() as *const _,")?;
2991        writeln!(code, "                total_raw as u64,")?;
2992        writeln!(
2993            code,
2994            "                MTLResourceOptions::StorageModeShared,"
2995        )?;
2996        writeln!(code, "            )")?;
2997        writeln!(code, "        }};")?;
2998        writeln!(code)?;
2999    }
3000
3001    if is_q4 {
3002        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3003        // We load them directly into Metal buffers without dequantizing,
3004        // and use the matmul_vec_q4 shader that operates on quantized data.
3005        // This quarters GPU memory usage vs f32 dequantization.
3006        writeln!(
3007            code,
3008            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3009        )?;
3010        writeln!(
3011            code,
3012            "        // as raw bytes into a Metal buffer (no dequantization)."
3013        )?;
3014        writeln!(
3015            code,
3016            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3017        )?;
3018        writeln!(
3019            code,
3020            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3021        )?;
3022        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3023        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3024        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3025        writeln!(code, "            let cur = cursor.get();")?;
3026        writeln!(
3027            code,
3028            "            let data = &weights[cur..cur + total_raw];"
3029        )?;
3030        writeln!(code, "            cursor.set(cur + total_raw);")?;
3031        writeln!(code, "            device.new_buffer_with_data(")?;
3032        writeln!(code, "                data.as_ptr() as *const _,")?;
3033        writeln!(code, "                total_raw as u64,")?;
3034        writeln!(
3035            code,
3036            "                MTLResourceOptions::StorageModeShared,"
3037        )?;
3038        writeln!(code, "            )")?;
3039        writeln!(code, "        }};")?;
3040        writeln!(code)?;
3041        writeln!(
3042            code,
3043            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3044        )?;
3045        writeln!(
3046            code,
3047            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3048        )?;
3049        writeln!(
3050            code,
3051            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3052        )?;
3053        writeln!(
3054            code,
3055            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3056        )?;
3057        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3058        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3059        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3060        writeln!(code, "            let cur = cursor.get();")?;
3061        writeln!(
3062            code,
3063            "            let data = &weights[cur..cur + total_raw];"
3064        )?;
3065        writeln!(code, "            cursor.set(cur + total_raw);")?;
3066        writeln!(code, "            device.new_buffer_with_data(")?;
3067        writeln!(code, "                data.as_ptr() as *const _,")?;
3068        writeln!(code, "                total_raw as u64,")?;
3069        writeln!(
3070            code,
3071            "                MTLResourceOptions::StorageModeShared,"
3072        )?;
3073        writeln!(code, "            )")?;
3074        writeln!(code, "        }};")?;
3075        writeln!(code)?;
3076    }
3077
3078    writeln!(
3079        code,
3080        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3081    )?;
3082    writeln!(code)?;
3083
3084    // Per-layer weights
3085    writeln!(
3086        code,
3087        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3088    )?;
3089    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3090
3091    // attn_norm is always f32
3092    writeln!(
3093        code,
3094        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3095    )?;
3096
3097    let qkv_rows = hidden + 2 * kv_dim;
3098    if is_q8 {
3099        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3100        writeln!(
3101            code,
3102            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3103        )?;
3104        writeln!(
3105            code,
3106            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3107        )?;
3108    } else if is_q4 {
3109        // Fused Q+K+V weight: read all three consecutive Q4_0 matrices as one buffer
3110        writeln!(
3111            code,
3112            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3113        )?;
3114        writeln!(
3115            code,
3116            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3117        )?;
3118    } else {
3119        // Fused Q+K+V weight: read all three as a single contiguous f32 buffer
3120        writeln!(
3121            code,
3122            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3123        )?;
3124        writeln!(
3125            code,
3126            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3127        )?;
3128    }
3129
3130    // ffn_norm is always f32
3131    writeln!(
3132        code,
3133        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3134    )?;
3135
3136    let gate_up_rows = 2 * intermediate;
3137    if is_q8 {
3138        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3139        writeln!(
3140            code,
3141            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3142        )?;
3143        writeln!(
3144            code,
3145            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3146        )?;
3147    } else if is_q4 {
3148        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3149        writeln!(
3150            code,
3151            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3152        )?;
3153        writeln!(
3154            code,
3155            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3156        )?;
3157    } else {
3158        // Fused gate+up weight: read both as a single contiguous f32 buffer
3159        writeln!(
3160            code,
3161            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3162        )?;
3163        writeln!(
3164            code,
3165            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3166        )?;
3167    }
3168
3169    writeln!(code, "            layers.push(LayerBuffers {{")?;
3170    writeln!(code, "                attn_norm,")?;
3171    writeln!(code, "                qkv_weight,")?;
3172    writeln!(code, "                o_weight,")?;
3173    writeln!(code, "                ffn_norm,")?;
3174    writeln!(code, "                gate_up_weight,")?;
3175    writeln!(code, "                down_weight,")?;
3176    writeln!(code, "            }});")?;
3177    writeln!(code, "        }}")?;
3178    writeln!(code)?;
3179
3180    // final_norm is always f32
3181    writeln!(
3182        code,
3183        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3184    )?;
3185    writeln!(code)?;
3186
3187    // lm_head
3188    if is_q8 {
3189        writeln!(
3190            code,
3191            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3192        )?;
3193    } else if is_q4 {
3194        writeln!(
3195            code,
3196            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3197        )?;
3198    } else {
3199        writeln!(
3200            code,
3201            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3202        )?;
3203    }
3204    writeln!(code)?;
3205
3206    // Working buffers
3207    let hidden_bytes = hidden * 4;
3208    let _kv_dim_bytes = kv_dim * 4;
3209    let intermediate_bytes = intermediate * 4;
3210    let vocab_bytes = vocab * 4;
3211    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3212
3213    writeln!(
3214        code,
3215        "        // Allocate working buffers (shared memory for zero-copy)"
3216    )?;
3217    writeln!(
3218        code,
3219        "        let opts = MTLResourceOptions::StorageModeShared;"
3220    )?;
3221    writeln!(
3222        code,
3223        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3224    )?;
3225    writeln!(
3226        code,
3227        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3228    )?;
3229    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3230    writeln!(
3231        code,
3232        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3233    )?;
3234    writeln!(
3235        code,
3236        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3237    )?;
3238    writeln!(
3239        code,
3240        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3241    )?;
3242    writeln!(
3243        code,
3244        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3245    )?;
3246    writeln!(
3247        code,
3248        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3249    )?;
3250    let gate_up_buf_bytes = 2 * intermediate * 4;
3251    writeln!(
3252        code,
3253        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3254    )?;
3255    writeln!(
3256        code,
3257        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3258    )?;
3259    writeln!(
3260        code,
3261        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3262    )?;
3263    writeln!(
3264        code,
3265        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3266    )?;
3267    writeln!(
3268        code,
3269        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3270    )?;
3271    writeln!(
3272        code,
3273        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3274    )?;
3275    writeln!(code)?;
3276
3277    // Batch prefill working buffers
3278    let batch_hidden_bytes = hidden * 4; // per-token
3279    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3280    let batch_gate_up_bytes = 2 * intermediate * 4;
3281    let batch_intermediate_bytes = intermediate * 4;
3282    writeln!(
3283        code,
3284        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3285    )?;
3286    writeln!(
3287        code,
3288        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3289    )?;
3290    writeln!(
3291        code,
3292        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3293    )?;
3294    writeln!(
3295        code,
3296        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3297    )?;
3298    writeln!(
3299        code,
3300        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3301    )?;
3302    writeln!(
3303        code,
3304        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3305    )?;
3306    writeln!(
3307        code,
3308        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3309    )?;
3310    writeln!(
3311        code,
3312        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3313    )?;
3314    writeln!(
3315        code,
3316        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3317    )?;
3318    writeln!(
3319        code,
3320        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3321    )?;
3322    writeln!(
3323        code,
3324        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3325    )?;
3326    writeln!(code)?;
3327
3328    // KV cache buffers
3329    writeln!(code, "        // KV cache buffers (per-layer)")?;
3330    writeln!(
3331        code,
3332        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3333    )?;
3334    writeln!(
3335        code,
3336        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3337    )?;
3338    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3339    writeln!(
3340        code,
3341        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3342    )?;
3343    writeln!(
3344        code,
3345        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3346    )?;
3347    writeln!(code, "        }}")?;
3348    writeln!(code)?;
3349
3350    writeln!(code, "        Self {{")?;
3351    writeln!(code, "            device,")?;
3352    writeln!(code, "            queue,")?;
3353    writeln!(code, "            matmul_pipeline,")?;
3354    writeln!(code, "            matmul_q8_pipeline,")?;
3355    writeln!(code, "            matmul_q4_pipeline,")?;
3356    writeln!(code, "            rms_norm_pipeline,")?;
3357    writeln!(code, "            rope_pipeline,")?;
3358    writeln!(code, "            softmax_pipeline,")?;
3359    writeln!(code, "            silu_mul_pipeline,")?;
3360    writeln!(code, "            silu_mul_fused_pipeline,")?;
3361    writeln!(code, "            add_pipeline,")?;
3362    writeln!(code, "            attention_pipeline,")?;
3363    writeln!(code, "            add_inplace_pipeline,")?;
3364    writeln!(code, "            copy_pipeline,")?;
3365    writeln!(code, "            copy_offset_pipeline,")?;
3366    writeln!(code, "            matmul_batch_pipeline,")?;
3367    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3368    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3369    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3370    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3371    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3372    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3373    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3374    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3375    writeln!(code, "            rms_norm_batch_pipeline,")?;
3376    writeln!(code, "            rope_batch_pipeline,")?;
3377    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3378    writeln!(code, "            add_inplace_batch_pipeline,")?;
3379    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3380    writeln!(code, "            attention_batch_pipeline,")?;
3381    writeln!(code, "            copy_kv_batch_pipeline,")?;
3382    writeln!(code, "            rope_qk_batch_pipeline,")?;
3383    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3384    writeln!(code, "            embed_buf,")?;
3385    writeln!(code, "            layers,")?;
3386    writeln!(code, "            norm_buf,")?;
3387    writeln!(code, "            lm_head_buf,")?;
3388    writeln!(code, "            hidden_buf,")?;
3389    writeln!(code, "            residual_buf,")?;
3390    writeln!(code, "            normed_buf,")?;
3391    writeln!(code, "            qkv_buf,")?;
3392    writeln!(code, "            attn_out_buf,")?;
3393    writeln!(code, "            attn_proj_buf,")?;
3394    writeln!(code, "            gate_up_buf,")?;
3395    writeln!(code, "            ffn_hidden_buf,")?;
3396    writeln!(code, "            ffn_out_buf,")?;
3397    writeln!(code, "            add_tmp_buf,")?;
3398    writeln!(code, "            logits_buf,")?;
3399    writeln!(code, "            batch_hidden_buf,")?;
3400    writeln!(code, "            batch_residual_buf,")?;
3401    writeln!(code, "            batch_qkv_buf,")?;
3402    writeln!(code, "            batch_attn_out_buf,")?;
3403    writeln!(code, "            batch_attn_proj_buf,")?;
3404    writeln!(code, "            batch_gate_up_buf,")?;
3405    writeln!(code, "            batch_ffn_hidden_buf,")?;
3406    writeln!(code, "            batch_ffn_out_buf,")?;
3407    writeln!(code, "            batch_tokens_buf,")?;
3408    writeln!(code, "            batch_positions_buf,")?;
3409    writeln!(code, "            k_cache,")?;
3410    writeln!(code, "            v_cache,")?;
3411    writeln!(code, "            pos: 0,")?;
3412    writeln!(code, "            prev_cmd: None,")?;
3413    writeln!(code, "        }}")?;
3414    writeln!(code, "    }}")?;
3415    writeln!(code)?;
3416
3417    // ── forward() ──
3418    writeln!(
3419        code,
3420        "    /// Run the forward pass for a single token at the current position."
3421    )?;
3422    writeln!(code, "    ///")?;
3423    writeln!(
3424        code,
3425        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3426    )?;
3427    writeln!(code, "    ///")?;
3428    writeln!(
3429        code,
3430        "    /// All GPU operations are encoded into a single command buffer and"
3431    )?;
3432    writeln!(
3433        code,
3434        "    /// committed once at the end, avoiding per-operation synchronization."
3435    )?;
3436    writeln!(
3437        code,
3438        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3439    )?;
3440    writeln!(
3441        code,
3442        "        // Wait for any pending prefill command buffer"
3443    )?;
3444    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3445    writeln!(code, "            prev.wait_until_completed();")?;
3446    writeln!(code, "        }}")?;
3447    writeln!(code)?;
3448    writeln!(code, "        let pos = self.pos;")?;
3449    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3450    writeln!(code)?;
3451
3452    // Single compute encoder for the entire forward pass — no blit encoder
3453    // transitions. Copy operations use compute copy kernels instead of blits.
3454    let matmul_fn = if is_q8 {
3455        "dispatch_matmul_q8"
3456    } else if is_q4 {
3457        "dispatch_matmul_q4"
3458    } else {
3459        "dispatch_matmul"
3460    };
3461
3462    writeln!(
3463        code,
3464        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3465    )?;
3466    writeln!(code, "        {{")?;
3467    writeln!(
3468        code,
3469        "            let enc = cmd.new_compute_command_encoder();"
3470    )?;
3471    writeln!(code)?;
3472
3473    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3474    writeln!(
3475        code,
3476        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3477    )?;
3478    writeln!(
3479        code,
3480        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3481    )?;
3482    writeln!(
3483        code,
3484        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3485    )?;
3486    writeln!(
3487        code,
3488        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3489        hidden * 4,
3490    )?;
3491    writeln!(code, "            unsafe {{")?;
3492    writeln!(
3493        code,
3494        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
3495    )?;
3496    writeln!(
3497        code,
3498        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3499    )?;
3500    writeln!(
3501        code,
3502        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
3503    )?;
3504    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
3505    writeln!(
3506        code,
3507        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3508    )?;
3509    writeln!(code, "                    hidden_ptr,")?;
3510    writeln!(code, "                    HIDDEN_SIZE,")?;
3511    writeln!(code, "                );")?;
3512    writeln!(
3513        code,
3514        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3515    )?;
3516    writeln!(code, "            }}")?;
3517    writeln!(code)?;
3518
3519    // 2. Transformer layers
3520    writeln!(code, "            // 2. Transformer layers")?;
3521    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3522    writeln!(code)?;
3523    let q_byte_offset = 0usize;
3524    let k_byte_offset = hidden * 4;
3525    let v_byte_offset = (hidden + kv_dim) * 4;
3526
3527    writeln!(
3528        code,
3529        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
3530    )?;
3531    writeln!(
3532        code,
3533        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3534    )?;
3535    writeln!(
3536        code,
3537        "                // Fused Q+K+V matmul: single dispatch for all three projections"
3538    )?;
3539    writeln!(
3540        code,
3541        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3542    )?;
3543    writeln!(
3544        code,
3545        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
3546    )?;
3547    writeln!(
3548        code,
3549        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3550    )?;
3551    writeln!(
3552        code,
3553        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3554    )?;
3555    writeln!(code)?;
3556    writeln!(
3557        code,
3558        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
3559    )?;
3560    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3561    writeln!(
3562        code,
3563        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3564    )?;
3565    writeln!(
3566        code,
3567        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3568    )?;
3569    writeln!(code)?;
3570    writeln!(
3571        code,
3572        "                // Attention using Q from qkv_buf (offset 0)"
3573    )?;
3574    writeln!(
3575        code,
3576        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3577    )?;
3578    writeln!(
3579        code,
3580        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3581    )?;
3582    writeln!(
3583        code,
3584        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3585    )?;
3586    writeln!(
3587        code,
3588        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
3589    )?;
3590    writeln!(
3591        code,
3592        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3593    )?;
3594    writeln!(
3595        code,
3596        "                // Fused gate+up matmul: single dispatch for both projections"
3597    )?;
3598    writeln!(
3599        code,
3600        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3601    )?;
3602    writeln!(
3603        code,
3604        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3605    )?;
3606    writeln!(
3607        code,
3608        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3609    )?;
3610    writeln!(
3611        code,
3612        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3613    )?;
3614    writeln!(code, "            }}")?;
3615    writeln!(code)?;
3616
3617    // 3. Final RMS norm + logits
3618    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
3619    writeln!(
3620        code,
3621        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3622    )?;
3623    writeln!(
3624        code,
3625        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3626    )?;
3627    writeln!(code)?;
3628    writeln!(code, "            enc.end_encoding();")?;
3629    writeln!(code, "        }}")?;
3630    writeln!(code)?;
3631
3632    // 5. Single commit + wait, then read back logits
3633    writeln!(
3634        code,
3635        "        // 5. Commit all GPU work and wait for completion"
3636    )?;
3637    writeln!(code, "        cmd.commit();")?;
3638    writeln!(code, "        cmd.wait_until_completed();")?;
3639    writeln!(code)?;
3640    writeln!(code, "        // 6. Read back logits from GPU")?;
3641    writeln!(code, "        let logits = unsafe {{")?;
3642    writeln!(
3643        code,
3644        "            let ptr = self.logits_buf.contents() as *const f32;"
3645    )?;
3646    writeln!(
3647        code,
3648        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3649    )?;
3650    writeln!(code, "        }};")?;
3651    writeln!(code)?;
3652    writeln!(code, "        self.pos += 1;")?;
3653    writeln!(code, "        logits")?;
3654    writeln!(code, "    }}")?;
3655    writeln!(code)?;
3656
3657    // ── forward_profile: instrumented forward with per-stage timing (embed, layers, norm+logits) ──
3658    writeln!(
3659        code,
3660        "    /// Profiling forward pass that prints per-stage GPU timing."
3661    )?;
3662    writeln!(code, "    ///")?;
3663    writeln!(
3664        code,
3665        "    /// Each stage is committed and waited on separately so that GPU timestamps"
3666    )?;
3667    writeln!(
3668        code,
3669        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
3670    )?;
3671    writeln!(
3672        code,
3673        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
3674    )?;
3675    writeln!(
3676        code,
3677        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
3678    )?;
3679    writeln!(code, "        use std::time::Instant;")?;
3680    writeln!(code)?;
3681    writeln!(
3682        code,
3683        "        // Wait for any pending prefill command buffer"
3684    )?;
3685    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3686    writeln!(code, "            prev.wait_until_completed();")?;
3687    writeln!(code, "        }}")?;
3688    writeln!(code)?;
3689    writeln!(code, "        let pos = self.pos;")?;
3690    writeln!(code)?;
3691
3692    // Stage: embedding (CPU, no GPU)
3693    writeln!(
3694        code,
3695        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
3696    )?;
3697    writeln!(code, "        let t_embed = Instant::now();")?;
3698    writeln!(code, "        unsafe {{")?;
3699    writeln!(
3700        code,
3701        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
3702    )?;
3703    writeln!(
3704        code,
3705        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3706    )?;
3707    writeln!(
3708        code,
3709        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
3710    )?;
3711    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
3712    writeln!(
3713        code,
3714        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3715    )?;
3716    writeln!(code, "                hidden_ptr,")?;
3717    writeln!(code, "                HIDDEN_SIZE,")?;
3718    writeln!(code, "            );")?;
3719    writeln!(
3720        code,
3721        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3722    )?;
3723    writeln!(code, "        }}")?;
3724    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
3725    writeln!(code)?;
3726
3727    // Stage: Transformer layers (all together on GPU)
3728    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
3729    writeln!(code, "        let t_layers = Instant::now();")?;
3730    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3731    writeln!(code, "        {{")?;
3732    writeln!(
3733        code,
3734        "            let enc = cmd.new_compute_command_encoder();"
3735    )?;
3736    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3737    writeln!(
3738        code,
3739        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3740    )?;
3741    writeln!(
3742        code,
3743        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3744    )?;
3745    writeln!(
3746        code,
3747        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3748    )?;
3749    writeln!(
3750        code,
3751        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3752    )?;
3753    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3754    writeln!(
3755        code,
3756        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3757    )?;
3758    writeln!(
3759        code,
3760        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3761    )?;
3762    writeln!(
3763        code,
3764        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3765    )?;
3766    writeln!(
3767        code,
3768        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3769    )?;
3770    writeln!(
3771        code,
3772        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3773    )?;
3774    writeln!(
3775        code,
3776        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3777    )?;
3778    writeln!(
3779        code,
3780        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3781    )?;
3782    writeln!(
3783        code,
3784        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3785    )?;
3786    writeln!(
3787        code,
3788        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3789    )?;
3790    writeln!(
3791        code,
3792        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3793    )?;
3794    writeln!(code, "            }}")?;
3795    writeln!(code, "            enc.end_encoding();")?;
3796    writeln!(code, "        }}")?;
3797    writeln!(code, "        cmd.commit();")?;
3798    writeln!(code, "        cmd.wait_until_completed();")?;
3799    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
3800    writeln!(code)?;
3801
3802    // Stage: Final norm + logits
3803    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
3804    writeln!(code, "        let t_logits = Instant::now();")?;
3805    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
3806    writeln!(code, "        {{")?;
3807    writeln!(
3808        code,
3809        "            let enc = cmd2.new_compute_command_encoder();"
3810    )?;
3811    writeln!(
3812        code,
3813        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3814    )?;
3815    writeln!(
3816        code,
3817        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3818    )?;
3819    writeln!(code, "            enc.end_encoding();")?;
3820    writeln!(code, "        }}")?;
3821    writeln!(code, "        cmd2.commit();")?;
3822    writeln!(code, "        cmd2.wait_until_completed();")?;
3823    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
3824    writeln!(code)?;
3825
3826    // Print profile results
3827    writeln!(
3828        code,
3829        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
3830    )?;
3831    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
3832    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
3833    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
3834    writeln!(
3835        code,
3836        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
3837    )?;
3838    writeln!(code)?;
3839
3840    // Read back logits
3841    writeln!(code, "        let logits = unsafe {{")?;
3842    writeln!(
3843        code,
3844        "            let ptr = self.logits_buf.contents() as *const f32;"
3845    )?;
3846    writeln!(
3847        code,
3848        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3849    )?;
3850    writeln!(code, "        }};")?;
3851    writeln!(code)?;
3852    writeln!(code, "        self.pos += 1;")?;
3853    writeln!(code, "        logits")?;
3854    writeln!(code, "    }}")?;
3855    writeln!(code)?;
3856
3857    // ── forward_prefill: single-token async forward (backward compat) ──
3858    writeln!(
3859        code,
3860        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
3861    )?;
3862    writeln!(code, "    ///")?;
3863    writeln!(
3864        code,
3865        "    /// Commits the command buffer without waiting, enabling double-buffered"
3866    )?;
3867    writeln!(
3868        code,
3869        "    /// execution: GPU processes token N while CPU encodes token N+1."
3870    )?;
3871    writeln!(
3872        code,
3873        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
3874    )?;
3875    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
3876    writeln!(code, "    }}")?;
3877    writeln!(code)?;
3878
3879    // ── forward_prefill_batch: batched prefill for multiple tokens ──
3880    // Batched matmuls for QKV/O/FFN projections, plus one batched causal-attention dispatch per layer.
3881    let batch_matmul_fn = if is_q8 {
3882        "dispatch_matmul_q8_batch"
3883    } else if is_q4 {
3884        "dispatch_matmul_q4_batch"
3885    } else {
3886        "dispatch_matmul_batch"
3887    };
3888
3889    writeln!(
3890        code,
3891        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
3892    )?;
3893    writeln!(code, "    ///")?;
3894    writeln!(
3895        code,
3896        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
3897    )?;
3898    writeln!(
3899        code,
3900        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
3901    )?;
3902    writeln!(
3903        code,
3904        "    /// This provides significant speedup during prompt prefill."
3905    )?;
3906    writeln!(
3907        code,
3908        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
3909    )?;
3910    writeln!(code, "        let m = tokens.len().min(MAX_BATCH_SIZE);")?;
3911    writeln!(code, "        if m == 0 {{ return; }}")?;
3912    writeln!(code, "        let start_pos = self.pos;")?;
3913    writeln!(code)?;
3914    writeln!(code, "        // Wait for any pending command buffer")?;
3915    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3916    writeln!(code, "            prev.wait_until_completed();")?;
3917    writeln!(code, "        }}")?;
3918    writeln!(code)?;
3919
3920    // Upload token IDs and positions to GPU
3921    writeln!(
3922        code,
3923        "        // Upload token IDs and positions to GPU buffers"
3924    )?;
3925    writeln!(code, "        unsafe {{")?;
3926    writeln!(
3927        code,
3928        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
3929    )?;
3930    writeln!(
3931        code,
3932        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
3933    )?;
3934    writeln!(code, "            for i in 0..m {{")?;
3935    writeln!(code, "                *tok_ptr.add(i) = tokens[i];")?;
3936    writeln!(
3937        code,
3938        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
3939    )?;
3940    writeln!(code, "            }}")?;
3941    writeln!(code, "        }}")?;
3942    writeln!(code)?;
3943
3944    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3945    writeln!(code, "        {{")?;
3946    writeln!(
3947        code,
3948        "            let enc = cmd.new_compute_command_encoder();"
3949    )?;
3950    writeln!(code)?;
3951
3952    // 1. Batch embedding lookup
3953    writeln!(
3954        code,
3955        "            // 1. Batch embedding lookup: copy all token embeddings at once"
3956    )?;
3957    writeln!(
3958        code,
3959        "            self.dispatch_copy_embedding_batch(&enc, m);"
3960    )?;
3961    // Copy batch_hidden -> batch_residual
3962    writeln!(
3963        code,
3964        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
3965    )?;
3966    writeln!(code)?;
3967
3968    // 2. Transformer layers
3969    writeln!(code, "            // 2. Transformer layers")?;
3970    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3971    writeln!(code)?;
3972
3973    // Batch RMS norm: residual -> hidden (batched)
3974    writeln!(
3975        code,
3976        "                // Batch RMS norm: batch_residual -> batch_hidden"
3977    )?;
3978    writeln!(
3979        code,
3980        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
3981    )?;
3982
3983    // Batch QKV matmul
3984    writeln!(
3985        code,
3986        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
3987    )?;
3988    writeln!(
3989        code,
3990        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
3991    )?;
3992    writeln!(code)?;
3993
3994    // Fused RoPE on Q+K portions in a single dispatch
3995    let k_float_offset = hidden;
3996    writeln!(
3997        code,
3998        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
3999    )?;
4000    writeln!(
4001        code,
4002        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4003    )?;
4004    writeln!(code)?;
4005
4006    // Fused KV cache update: copy both K and V in a single dispatch
4007    let v_float_offset = hidden + kv_dim;
4008    writeln!(
4009        code,
4010        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4011    )?;
4012    writeln!(
4013        code,
4014        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4015    )?;
4016    writeln!(code)?;
4017
4018    // Batched causal attention: ONE dispatch for all M tokens
4019    writeln!(
4020        code,
4021        "                // Batched causal attention: one dispatch for all M tokens"
4022    )?;
4023    writeln!(
4024        code,
4025        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4026    )?;
4027    writeln!(code)?;
4028
4029    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4030    writeln!(code, "                // Batched O projection")?;
4031    writeln!(
4032        code,
4033        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4034    )?;
4035    writeln!(code)?;
4036
4037    // Batch add: residual += attn_proj for all tokens
4038    writeln!(
4039        code,
4040        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4041    )?;
4042    writeln!(code)?;
4043
4044    // Batch FFN
4045    writeln!(
4046        code,
4047        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4048    )?;
4049    writeln!(
4050        code,
4051        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4052    )?;
4053    writeln!(
4054        code,
4055        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4056    )?;
4057    writeln!(
4058        code,
4059        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4060    )?;
4061    writeln!(
4062        code,
4063        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4064    )?;
4065    writeln!(
4066        code,
4067        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4068    )?;
4069    writeln!(code, "            }}")?;
4070    writeln!(code)?;
4071
4072    // Copy last token's residual to single-token residual_buf for next forward() call
4073    writeln!(
4074        code,
4075        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4076    )?;
4077    writeln!(
4078        code,
4079        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4080    )?;
4081    writeln!(code)?;
4082    writeln!(code, "            enc.end_encoding();")?;
4083    writeln!(code, "        }}")?;
4084    writeln!(code)?;
4085
4086    writeln!(code, "        cmd.commit();")?;
4087    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4088    writeln!(code, "        self.pos += m;")?;
4089    writeln!(code, "    }}")?;
4090    writeln!(code)?;
4091
4092    // ── reset() — rewind KV cache position for new inference requests ──
4093    writeln!(
4094        code,
4095        "    /// Reset the model state for a new inference request."
4096    )?;
4097    writeln!(code, "    pub fn reset(&mut self) {{")?;
4098    writeln!(code, "        self.pos = 0;")?;
4099    writeln!(code, "        self.prev_cmd = None;")?;
4100    writeln!(code, "    }}")?;
4101    writeln!(code)?;
4102
4103    // ── Private dispatch helpers (all take a shared compute encoder) ──
4104    writeln!(
4105        code,
4106        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4107    )?;
4108    writeln!(
4109        code,
4110        "    // These methods set pipeline state + buffers + dispatch on an existing"
4111    )?;
4112    writeln!(
4113        code,
4114        "    // encoder, avoiding per-operation encoder creation overhead."
4115    )?;
4116    writeln!(code)?;
4117
4118    // dispatch_rms_norm
4119    writeln!(
4120        code,
4121        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4122    )?;
4123    writeln!(
4124        code,
4125        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4126    )?;
4127    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4128    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4129    writeln!(
4130        code,
4131        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4132    )?;
4133    writeln!(
4134        code,
4135        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4136    )?;
4137    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4138    writeln!(
4139        code,
4140        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4141    )?;
4142    writeln!(
4143        code,
4144        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4145    )?;
4146    writeln!(
4147        code,
4148        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4149    )?;
4150    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4151    writeln!(
4152        code,
4153        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4154    )?;
4155    writeln!(
4156        code,
4157        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4158    )?;
4159    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4160    writeln!(code, "    }}")?;
4161    writeln!(code)?;
4162
4163    // dispatch_matmul
4164    writeln!(
4165        code,
4166        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4167    )?;
4168    writeln!(
4169        code,
4170        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4171    )?;
4172    writeln!(code, "        let r: u32 = rows as u32;")?;
4173    writeln!(code, "        let c: u32 = cols as u32;")?;
4174    writeln!(
4175        code,
4176        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4177    )?;
4178    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4179    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4180    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4181    writeln!(
4182        code,
4183        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4184    )?;
4185    writeln!(
4186        code,
4187        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4188    )?;
4189    writeln!(
4190        code,
4191        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4192    )?;
4193    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4194    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4195    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4196    writeln!(
4197        code,
4198        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4199    )?;
4200    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4201    writeln!(code, "    }}")?;
4202    writeln!(code)?;
4203
4204    // dispatch_matmul_q8
4205    writeln!(
4206        code,
4207        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4208    )?;
4209    writeln!(
4210        code,
4211        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4212    )?;
4213    writeln!(
4214        code,
4215        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4216    )?;
4217    writeln!(code, "        let r: u32 = rows as u32;")?;
4218    writeln!(code, "        let c: u32 = cols as u32;")?;
4219    writeln!(
4220        code,
4221        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4222    )?;
4223    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4224    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4225    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4226    writeln!(
4227        code,
4228        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4229    )?;
4230    writeln!(
4231        code,
4232        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4233    )?;
4234    writeln!(
4235        code,
4236        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4237    )?;
4238    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4239    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4240    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4241    writeln!(
4242        code,
4243        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4244    )?;
4245    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4246    writeln!(code, "    }}")?;
4247    writeln!(code)?;
4248
4249    // dispatch_matmul_q4
4250    writeln!(
4251        code,
4252        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4253    )?;
4254    writeln!(
4255        code,
4256        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4257    )?;
4258    writeln!(
4259        code,
4260        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4261    )?;
4262    writeln!(code, "        let r: u32 = rows as u32;")?;
4263    writeln!(code, "        let c: u32 = cols as u32;")?;
4264    writeln!(
4265        code,
4266        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4267    )?;
4268    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4269    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4270    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4271    writeln!(
4272        code,
4273        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4274    )?;
4275    writeln!(
4276        code,
4277        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4278    )?;
4279    writeln!(
4280        code,
4281        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4282    )?;
4283    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4284    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4285    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4286    writeln!(
4287        code,
4288        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4289    )?;
4290    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4291    writeln!(code, "    }}")?;
4292    writeln!(code)?;
4293
4294    // dispatch_rope
4295    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4296    writeln!(
4297        code,
4298        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4299    )?;
4300    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4301    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4302    writeln!(code, "        let p: u32 = pos as u32;")?;
4303    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4304    writeln!(
4305        code,
4306        "        let total_pairs = num_heads * (head_dim / 2);"
4307    )?;
4308    writeln!(
4309        code,
4310        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4311    )?;
4312    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4313    writeln!(
4314        code,
4315        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4316    )?;
4317    writeln!(
4318        code,
4319        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4320    )?;
4321    writeln!(
4322        code,
4323        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4324    )?;
4325    writeln!(
4326        code,
4327        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4328    )?;
4329    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4330    writeln!(
4331        code,
4332        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4333    )?;
4334    writeln!(
4335        code,
4336        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4337    )?;
4338    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4339    writeln!(code, "    }}")?;
4340    writeln!(code)?;
4341
4342    // dispatch_rope_offset
4343    writeln!(
4344        code,
4345        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4346    )?;
4347    writeln!(
4348        code,
4349        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4350    )?;
4351    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4352    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4353    writeln!(code, "        let p: u32 = pos as u32;")?;
4354    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4355    writeln!(
4356        code,
4357        "        let total_pairs = num_heads * (head_dim / 2);"
4358    )?;
4359    writeln!(
4360        code,
4361        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4362    )?;
4363    writeln!(
4364        code,
4365        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4366    )?;
4367    writeln!(
4368        code,
4369        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4370    )?;
4371    writeln!(
4372        code,
4373        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4374    )?;
4375    writeln!(
4376        code,
4377        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4378    )?;
4379    writeln!(
4380        code,
4381        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4382    )?;
4383    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4384    writeln!(
4385        code,
4386        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4387    )?;
4388    writeln!(
4389        code,
4390        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4391    )?;
4392    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4393    writeln!(code, "    }}")?;
4394    writeln!(code)?;
4395
4396    // dispatch_attention
4397    writeln!(
4398        code,
4399        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4400    )?;
4401    writeln!(
4402        code,
4403        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4404    )?;
4405    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4406    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4407    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4408    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4409    writeln!(
4410        code,
4411        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4412    )?;
4413    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4414    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4415    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4416    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4417    writeln!(
4418        code,
4419        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4420    )?;
4421    writeln!(
4422        code,
4423        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4424    )?;
4425    writeln!(
4426        code,
4427        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4428    )?;
4429    writeln!(
4430        code,
4431        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4432    )?;
4433    writeln!(
4434        code,
4435        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4436    )?;
4437    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4438    writeln!(
4439        code,
4440        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4441    )?;
4442    writeln!(
4443        code,
4444        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4445    )?;
4446    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4447    writeln!(code, "    }}")?;
4448    writeln!(code)?;
4449
4450    // dispatch_attention_offset
4451    writeln!(
4452        code,
4453        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4454    )?;
4455    writeln!(
4456        code,
4457        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4458    )?;
4459    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4460    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4461    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4462    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4463    writeln!(
4464        code,
4465        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4466    )?;
4467    writeln!(
4468        code,
4469        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
4470    )?;
4471    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4472    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4473    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4474    writeln!(
4475        code,
4476        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4477    )?;
4478    writeln!(
4479        code,
4480        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4481    )?;
4482    writeln!(
4483        code,
4484        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4485    )?;
4486    writeln!(
4487        code,
4488        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4489    )?;
4490    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4491    writeln!(
4492        code,
4493        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4494    )?;
4495    writeln!(
4496        code,
4497        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4498    )?;
4499    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4500    writeln!(code, "    }}")?;
4501    writeln!(code)?;
4502
4503    // dispatch_silu_mul
4504    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
4505    writeln!(
4506        code,
4507        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
4508    )?;
4509    writeln!(code, "        let count: u32 = n as u32;")?;
4510    writeln!(
4511        code,
4512        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
4513    )?;
4514    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
4515    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
4516    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4517    writeln!(
4518        code,
4519        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4520    )?;
4521    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4522    writeln!(
4523        code,
4524        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4525    )?;
4526    writeln!(
4527        code,
4528        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4529    )?;
4530    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4531    writeln!(code, "    }}")?;
4532    writeln!(code)?;
4533
4534    // dispatch_silu_mul_fused
4535    writeln!(
4536        code,
4537        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
4538    )?;
4539    writeln!(
4540        code,
4541        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
4542    )?;
4543    writeln!(
4544        code,
4545        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
4546    )?;
4547    writeln!(code, "        let count: u32 = n as u32;")?;
4548    writeln!(
4549        code,
4550        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
4551    )?;
4552    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
4553    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
4554    writeln!(
4555        code,
4556        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4557    )?;
4558    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4559    writeln!(
4560        code,
4561        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4562    )?;
4563    writeln!(
4564        code,
4565        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4566    )?;
4567    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4568    writeln!(code, "    }}")?;
4569    writeln!(code)?;
4570
4571    // dispatch_copy (simple src -> dst copy via compute kernel)
4572    writeln!(
4573        code,
4574        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
4575    )?;
4576    writeln!(
4577        code,
4578        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4579    )?;
4580    writeln!(code, "        let n: u32 = count as u32;")?;
4581    writeln!(
4582        code,
4583        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4584    )?;
4585    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4586    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4587    writeln!(
4588        code,
4589        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4590    )?;
4591    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4592    writeln!(
4593        code,
4594        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4595    )?;
4596    writeln!(
4597        code,
4598        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4599    )?;
4600    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4601    writeln!(code, "    }}")?;
4602    writeln!(code)?;
4603
4604    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
4605    writeln!(
4606        code,
4607        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
4608    )?;
4609    writeln!(
4610        code,
4611        "    /// Used for embedding table lookup (copy a specific row)."
4612    )?;
4613    writeln!(
4614        code,
4615        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
4616    )?;
4617    writeln!(code, "        let off: u32 = src_offset as u32;")?;
4618    writeln!(code, "        let n: u32 = count as u32;")?;
4619    writeln!(
4620        code,
4621        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
4622    )?;
4623    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4624    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4625    writeln!(
4626        code,
4627        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
4628    )?;
4629    writeln!(
4630        code,
4631        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4632    )?;
4633    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4634    writeln!(
4635        code,
4636        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4637    )?;
4638    writeln!(
4639        code,
4640        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4641    )?;
4642    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4643    writeln!(code, "    }}")?;
4644    writeln!(code)?;
4645
4646    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
4647    writeln!(
4648        code,
4649        "    /// Dispatch copy from source at byte offset to destination at float offset."
4650    )?;
4651    writeln!(
4652        code,
4653        "    /// Used for KV cache updates from fused QKV buffer."
4654    )?;
4655    writeln!(
4656        code,
4657        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4658    )?;
4659    writeln!(code, "        let n: u32 = count as u32;")?;
4660    writeln!(
4661        code,
4662        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4663    )?;
4664    writeln!(
4665        code,
4666        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4667    )?;
4668    writeln!(
4669        code,
4670        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4671    )?;
4672    writeln!(
4673        code,
4674        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4675    )?;
4676    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4677    writeln!(
4678        code,
4679        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4680    )?;
4681    writeln!(
4682        code,
4683        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4684    )?;
4685    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4686    writeln!(code, "    }}")?;
4687    writeln!(code)?;
4688
4689    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
4690    writeln!(
4691        code,
4692        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
4693    )?;
4694    writeln!(
4695        code,
4696        "    /// Used for KV cache updates (write to a specific position in the cache)."
4697    )?;
4698    writeln!(
4699        code,
4700        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
4701    )?;
4702    writeln!(code, "        let n: u32 = count as u32;")?;
4703    writeln!(
4704        code,
4705        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4706    )?;
4707    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4708    writeln!(
4709        code,
4710        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
4711    )?;
4712    writeln!(
4713        code,
4714        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4715    )?;
4716    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4717    writeln!(
4718        code,
4719        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4720    )?;
4721    writeln!(
4722        code,
4723        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4724    )?;
4725    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4726    writeln!(code, "    }}")?;
4727    writeln!(code)?;
4728
4729    // dispatch_add_inplace (residual connection, no blit needed)
4730    writeln!(
4731        code,
4732        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
4733    )?;
4734    writeln!(
4735        code,
4736        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
4737    )?;
4738    writeln!(code, "        let count: u32 = n as u32;")?;
4739    writeln!(
4740        code,
4741        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
4742    )?;
4743    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
4744    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
4745    writeln!(
4746        code,
4747        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4748    )?;
4749    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4750    writeln!(
4751        code,
4752        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4753    )?;
4754    writeln!(
4755        code,
4756        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4757    )?;
4758    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4759    writeln!(code, "    }}")?;
4760    writeln!(code)?;
4761
4762    // ── Batched prefill dispatch helpers ──
4763    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
4764    writeln!(code)?;
4765
4766    // dispatch_copy_embedding_batch
4767    writeln!(
4768        code,
4769        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
4770    )?;
4771    writeln!(
4772        code,
4773        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
4774    )?;
4775    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
4776    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4777    writeln!(
4778        code,
4779        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
4780    )?;
4781    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
4782    writeln!(
4783        code,
4784        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
4785    )?;
4786    writeln!(
4787        code,
4788        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
4789    )?;
4790    writeln!(
4791        code,
4792        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
4793    )?;
4794    writeln!(
4795        code,
4796        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4797    )?;
4798    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
4799    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4800    writeln!(
4801        code,
4802        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4803    )?;
4804    writeln!(
4805        code,
4806        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4807    )?;
4808    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4809    writeln!(code, "    }}")?;
4810    writeln!(code)?;
4811
4812    // dispatch_rms_norm_batch
4813    writeln!(
4814        code,
4815        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
4816    )?;
4817    writeln!(
4818        code,
4819        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
4820    )?;
4821    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4822    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4823    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4824    writeln!(
4825        code,
4826        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
4827    )?;
4828    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
4829    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4830    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4831    writeln!(
4832        code,
4833        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4834    )?;
4835    writeln!(
4836        code,
4837        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4838    )?;
4839    writeln!(
4840        code,
4841        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4842    )?;
4843    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4844    writeln!(
4845        code,
4846        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
4847    )?;
4848    writeln!(
4849        code,
4850        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4851    )?;
4852    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4853    writeln!(code, "    }}")?;
4854    writeln!(code)?;
4855
4856    // dispatch_matmul_batch (f32)
4857    writeln!(
4858        code,
4859        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4860    )?;
4861    writeln!(
4862        code,
4863        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4864    )?;
4865    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4866    writeln!(code, "        let r: u32 = rows as u32;")?;
4867    writeln!(code, "        let c: u32 = cols as u32;")?;
4868    writeln!(
4869        code,
4870        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
4871    )?;
4872    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4873    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4874    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4875    writeln!(
4876        code,
4877        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4878    )?;
4879    writeln!(
4880        code,
4881        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4882    )?;
4883    writeln!(
4884        code,
4885        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4886    )?;
4887    writeln!(
4888        code,
4889        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
4890    )?;
4891    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
4892    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4893    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4894    writeln!(
4895        code,
4896        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4897    )?;
4898    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4899    writeln!(code, "    }}")?;
4900    writeln!(code)?;
4901
4902    // dispatch_matmul_q8_batch
4903    writeln!(
4904        code,
4905        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4906    )?;
4907    writeln!(code, "    ///")?;
4908    writeln!(
4909        code,
4910        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
4911    )?;
4912    writeln!(
4913        code,
4914        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
4915    )?;
4916    writeln!(
4917        code,
4918        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4919    )?;
4920    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4921    writeln!(code, "        let r: u32 = rows as u32;")?;
4922    writeln!(code, "        let c: u32 = cols as u32;")?;
4923    writeln!(
4924        code,
4925        "        // Tile sizes must match the Metal shader constants."
4926    )?;
4927    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
4928    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
4929    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
4930    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
4931    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
4932    writeln!(
4933        code,
4934        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
4935    )?;
4936    writeln!(
4937        code,
4938        "        // Prefer the large 32×32 tile when the problem supports it — halves"
4939    )?;
4940    writeln!(
4941        code,
4942        "        // dispatch count and reuses each weight load across 32 tokens."
4943    )?;
4944    writeln!(
4945        code,
4946        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
4947    )?;
4948    writeln!(
4949        code,
4950        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
4951    )?;
4952    writeln!(
4953        code,
4954        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
4955    )?;
4956    writeln!(
4957        code,
4958        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
4959    )?;
4960    writeln!(
4961        code,
4962        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
4963    )?;
4964    writeln!(
4965        code,
4966        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
4967    )?;
4968    writeln!(
4969        code,
4970        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
4971    )?;
4972    writeln!(
4973        code,
4974        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
4975    )?;
4976    writeln!(
4977        code,
4978        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
4979    )?;
4980    writeln!(code, "            let use_h4 = cols >= 2048;")?;
4981    writeln!(code, "            let pipe = if use_h4 {{")?;
4982    writeln!(code, "                if num_tokens >= 256 {{")?;
4983    writeln!(code, "                    &self.matmul_q8_mma32_hh4_pipeline")?;
4984    writeln!(code, "                }} else {{")?;
4985    writeln!(code, "                    &self.matmul_q8_mma32_h4_pipeline")?;
4986    writeln!(code, "                }}")?;
4987    writeln!(code, "            }} else {{")?;
4988    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
4989    writeln!(code, "            }};")?;
4990    writeln!(
4991        code,
4992        "            enc.set_compute_pipeline_state(pipe);"
4993    )?;
4994    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
4995    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
4996    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
4997    writeln!(
4998        code,
4999        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5000    )?;
5001    writeln!(
5002        code,
5003        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5004    )?;
5005    writeln!(
5006        code,
5007        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5008    )?;
5009    writeln!(
5010        code,
5011        "            let row_tgs = rows / MMA32_ROW_TILE;"
5012    )?;
5013    writeln!(
5014        code,
5015        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5016    )?;
5017    writeln!(
5018        code,
5019        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5020    )?;
5021    writeln!(
5022        code,
5023        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5024    )?;
5025    writeln!(
5026        code,
5027        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5028    )?;
5029    writeln!(
5030        code,
5031        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5032    )?;
5033    writeln!(
5034        code,
5035        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5036    )?;
5037    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5038    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5039    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5040    writeln!(
5041        code,
5042        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5043    )?;
5044    writeln!(
5045        code,
5046        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5047    )?;
5048    writeln!(
5049        code,
5050        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5051    )?;
5052    writeln!(
5053        code,
5054        "            let row_tgs = rows / MMA_ROW_TILE;"
5055    )?;
5056    writeln!(
5057        code,
5058        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5059    )?;
5060    writeln!(
5061        code,
5062        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5063    )?;
5064    writeln!(
5065        code,
5066        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5067    )?;
5068    writeln!(
5069        code,
5070        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5071    )?;
5072    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5073    writeln!(
5074        code,
5075        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5076    )?;
5077    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5078    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5079    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5080    writeln!(
5081        code,
5082        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5083    )?;
5084    writeln!(
5085        code,
5086        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5087    )?;
5088    writeln!(
5089        code,
5090        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5091    )?;
5092    writeln!(
5093        code,
5094        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5095    )?;
5096    writeln!(
5097        code,
5098        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5099    )?;
5100    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5101    writeln!(
5102        code,
5103        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5104    )?;
5105    writeln!(
5106        code,
5107        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5108    )?;
5109    writeln!(code, "        }} else {{")?;
5110    writeln!(
5111        code,
5112        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5113    )?;
5114    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5115    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5116    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5117    writeln!(
5118        code,
5119        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5120    )?;
5121    writeln!(
5122        code,
5123        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5124    )?;
5125    writeln!(
5126        code,
5127        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5128    )?;
5129    writeln!(
5130        code,
5131        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5132    )?;
5133    writeln!(
5134        code,
5135        "            let num_tg = (row_tgs * num_tokens) as u64;"
5136    )?;
5137    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5138    writeln!(
5139        code,
5140        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5141    )?;
5142    writeln!(
5143        code,
5144        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5145    )?;
5146    writeln!(code, "        }}")?;
5147    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5148    writeln!(code, "    }}")?;
5149    writeln!(code)?;
5150
5151    // dispatch_matmul_q4_batch
5152    writeln!(
5153        code,
5154        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5155    )?;
5156    writeln!(
5157        code,
5158        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5159    )?;
5160    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5161    writeln!(code, "        let r: u32 = rows as u32;")?;
5162    writeln!(code, "        let c: u32 = cols as u32;")?;
5163    writeln!(
5164        code,
5165        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5166    )?;
5167    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5168    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5169    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5170    writeln!(
5171        code,
5172        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5173    )?;
5174    writeln!(
5175        code,
5176        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5177    )?;
5178    writeln!(
5179        code,
5180        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5181    )?;
5182    writeln!(
5183        code,
5184        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5185    )?;
5186    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5187    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5188    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5189    writeln!(
5190        code,
5191        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5192    )?;
5193    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5194    writeln!(code, "    }}")?;
5195    writeln!(code)?;
5196
5197    // dispatch_rope_batch
5198    writeln!(
5199        code,
5200        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5201    )?;
5202    writeln!(
5203        code,
5204        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5205    )?;
5206    writeln!(
5207        code,
5208        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5209    )?;
5210    writeln!(
5211        code,
5212        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5213    )?;
5214    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5215    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5216    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5217    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5218    writeln!(
5219        code,
5220        "        let pairs_per_token = num_heads * (head_dim / 2);"
5221    )?;
5222    writeln!(
5223        code,
5224        "        let total_pairs = num_tokens * pairs_per_token;"
5225    )?;
    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data, i.e. it
    // accesses data[token * (num_heads * head_dim) + ...]. Our batch_qkv_buf is
    // [M, qkv_rows] with Q and K at offsets within each row, so the layout we actually
    // have is data[token * row_stride + data_float_offset + ...].
    //
    // Using rope_batch here would require the kernel to know the row_stride: since Q and K
    // are contiguous within each token's qkv_rows, we could pass the buffer at byte offset
    // (data_float_offset * 4) together with a stride parameter. The rope_batch kernel as
    // written does not take one; it expects [M, num_heads * head_dim].
    //
    // Simplest approach: use the single-token rope kernel for each token in a loop.
    // This is still efficient because all dispatches go through the same command encoder.
5238    writeln!(
5239        code,
5240        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5241    )?;
5242    writeln!(code, "        for t in 0..num_tokens {{")?;
5243    writeln!(
5244        code,
5245        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5246    )?;
5247    writeln!(
5248        code,
5249        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5250    )?;
5251    writeln!(
5252        code,
5253        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5254    )?;
5255    writeln!(
5256        code,
5257        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5258    )?;
5259    writeln!(
5260        code,
5261        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5262    )?;
5263    writeln!(
5264        code,
5265        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5266    )?;
5267    writeln!(
5268        code,
5269        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5270    )?;
5271    writeln!(
5272        code,
5273        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5274    )?;
5275    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5276    writeln!(
5277        code,
5278        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5279    )?;
5280    writeln!(
5281        code,
5282        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5283    )?;
5284    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5285    writeln!(code, "        }}")?;
5286    writeln!(code, "    }}")?;
5287    writeln!(code)?;
5288
5289    // dispatch_silu_mul_fused_batch
5290    writeln!(
5291        code,
5292        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5293    )?;
5294    writeln!(
5295        code,
5296        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5297    )?;
5298    writeln!(code, "        let count: u32 = n as u32;")?;
5299    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5300    writeln!(
5301        code,
5302        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5303    )?;
5304    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5305    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5306    writeln!(
5307        code,
5308        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5309    )?;
5310    writeln!(
5311        code,
5312        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5313    )?;
5314    writeln!(code, "        let total = num_tokens * n;")?;
5315    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5316    writeln!(
5317        code,
5318        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5319    )?;
5320    writeln!(
5321        code,
5322        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5323    )?;
5324    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5325    writeln!(code, "    }}")?;
5326    writeln!(code)?;
5327
5328    // dispatch_add_inplace_batch_n (add n elements in-place)
5329    writeln!(
5330        code,
5331        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5332    )?;
5333    writeln!(
5334        code,
5335        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5336    )?;
5337    writeln!(code, "        let count: u32 = total_n as u32;")?;
5338    writeln!(
5339        code,
5340        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5341    )?;
5342    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5343    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5344    writeln!(
5345        code,
5346        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5347    )?;
5348    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5349    writeln!(
5350        code,
5351        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5352    )?;
5353    writeln!(
5354        code,
5355        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5356    )?;
5357    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5358    writeln!(code, "    }}")?;
5359    writeln!(code)?;
5360
5361    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5362    writeln!(
5363        code,
5364        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5365    )?;
5366    writeln!(
5367        code,
5368        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5369    )?;
5370    writeln!(code, "        let n: u32 = count as u32;")?;
5371    writeln!(
5372        code,
5373        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5374    )?;
5375    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5376    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5377    writeln!(
5378        code,
5379        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5380    )?;
5381    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5382    writeln!(
5383        code,
5384        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5385    )?;
5386    writeln!(
5387        code,
5388        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5389    )?;
5390    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5391    writeln!(code, "    }}")?;
5392    writeln!(code)?;
5393
5394    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
5395    writeln!(
5396        code,
5397        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5398    )?;
5399    writeln!(
5400        code,
5401        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5402    )?;
5403    writeln!(code, "        let n: u32 = count as u32;")?;
5404    writeln!(
5405        code,
5406        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5407    )?;
5408    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5409    writeln!(
5410        code,
5411        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5412    )?;
5413    writeln!(
5414        code,
5415        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5416    )?;
5417    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5418    writeln!(
5419        code,
5420        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5421    )?;
5422    writeln!(
5423        code,
5424        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5425    )?;
5426    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5427    writeln!(code, "    }}")?;
5428    writeln!(code)?;
5429
5430    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
5431    writeln!(
5432        code,
5433        "    /// Copy from src at byte offset to dst at float offset."
5434    )?;
5435    writeln!(
5436        code,
5437        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5438    )?;
5439    writeln!(code, "        let n: u32 = count as u32;")?;
5440    writeln!(
5441        code,
5442        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5443    )?;
5444    writeln!(
5445        code,
5446        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5447    )?;
5448    writeln!(
5449        code,
5450        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5451    )?;
5452    writeln!(
5453        code,
5454        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5455    )?;
5456    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5457    writeln!(
5458        code,
5459        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5460    )?;
5461    writeln!(
5462        code,
5463        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5464    )?;
5465    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5466    writeln!(code, "    }}")?;
5467    writeln!(code)?;
5468
5469    // dispatch_copy_kv_batch
5470    writeln!(
5471        code,
5472        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
5473    )?;
5474    writeln!(
5475        code,
5476        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
5477    )?;
5478    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5479    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5480    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5481    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5482    writeln!(code, "        let so: u32 = src_offset as u32;")?;
5483    writeln!(
5484        code,
5485        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
5486    )?;
5487    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5488    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5489    writeln!(
5490        code,
5491        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5492    )?;
5493    writeln!(
5494        code,
5495        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5496    )?;
5497    writeln!(
5498        code,
5499        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5500    )?;
5501    writeln!(
5502        code,
5503        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5504    )?;
5505    writeln!(
5506        code,
5507        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
5508    )?;
5509    writeln!(code, "        let total = num_tokens * kv_dim;")?;
5510    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5511    writeln!(
5512        code,
5513        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5514    )?;
5515    writeln!(
5516        code,
5517        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5518    )?;
5519    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5520    writeln!(code, "    }}")?;
5521    writeln!(code)?;
5522
5523    // dispatch_attention_batch
5524    writeln!(
5525        code,
5526        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
5527    )?;
5528    writeln!(
5529        code,
5530        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
5531    )?;
5532    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5533    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5534    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5535    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5536    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5537    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
5538    writeln!(
5539        code,
5540        "        enc.set_compute_pipeline_state(&self.attention_batch_pipeline);"
5541    )?;
5542    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
5543    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5544    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5545    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5546    writeln!(
5547        code,
5548        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5549    )?;
5550    writeln!(
5551        code,
5552        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5553    )?;
5554    writeln!(
5555        code,
5556        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5557    )?;
5558    writeln!(
5559        code,
5560        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5561    )?;
5562    writeln!(
5563        code,
5564        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5565    )?;
5566    writeln!(
5567        code,
5568        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5569    )?;
5570    writeln!(
5571        code,
5572        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
5573    )?;
5574    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5575    writeln!(
5576        code,
5577        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
5578    )?;
5579    writeln!(
5580        code,
5581        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5582    )?;
5583    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5584    writeln!(code, "    }}")?;
5585    writeln!(code)?;
5586
5587    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
5588    writeln!(
5589        code,
5590        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
5591    )?;
5592    writeln!(
5593        code,
5594        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
5595    )?;
5596    writeln!(
5597        code,
5598        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
5599    )?;
5600    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5601    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5602    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
5603    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5604    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5605    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
5606    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5607    writeln!(
5608        code,
5609        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
5610    )?;
5611    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
5612    writeln!(
5613        code,
5614        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5615    )?;
5616    writeln!(
5617        code,
5618        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5619    )?;
5620    writeln!(
5621        code,
5622        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
5623    )?;
5624    writeln!(
5625        code,
5626        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5627    )?;
5628    writeln!(
5629        code,
5630        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5631    )?;
5632    writeln!(
5633        code,
5634        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5635    )?;
5636    writeln!(
5637        code,
5638        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5639    )?;
5640    writeln!(
5641        code,
5642        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
5643    )?;
5644    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5645    writeln!(
5646        code,
5647        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5648    )?;
5649    writeln!(
5650        code,
5651        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5652    )?;
5653    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5654    writeln!(code, "    }}")?;
5655    writeln!(code)?;
5656
5657    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
5658    writeln!(
5659        code,
5660        "    /// Dispatch fused K+V cache copy in one kernel launch."
5661    )?;
5662    writeln!(
5663        code,
5664        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
5665    )?;
5666    writeln!(
5667        code,
5668        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
5669    )?;
5670    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5671    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5672    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5673    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5674    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
5675    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
5676    writeln!(
5677        code,
5678        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
5679    )?;
5680    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5681    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
5682    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
5683    writeln!(
5684        code,
5685        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5686    )?;
5687    writeln!(
5688        code,
5689        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5690    )?;
5691    writeln!(
5692        code,
5693        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5694    )?;
5695    writeln!(
5696        code,
5697        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5698    )?;
5699    writeln!(
5700        code,
5701        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
5702    )?;
5703    writeln!(
5704        code,
5705        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
5706    )?;
5707    writeln!(
5708        code,
5709        "        let total = num_tokens * kv_dim * 2;  // K + V"
5710    )?;
5711    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5712    writeln!(
5713        code,
5714        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5715    )?;
5716    writeln!(
5717        code,
5718        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5719    )?;
5720    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5721    writeln!(code, "    }}")?;
5722
5723    writeln!(code, "}}")?;
5724    writeln!(code)?;
5725
5726    Ok(())
5727}
5728
5729fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
5730    writeln!(
5731        code,
5732        "// ── Helper functions ──────────────────────────────────"
5733    )?;
5734    writeln!(code)?;
5735    writeln!(
5736        code,
5737        "/// Create a compute pipeline from a named function in the Metal library."
5738    )?;
5739    writeln!(
5740        code,
5741        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
5742    )?;
5743    writeln!(
5744        code,
5745        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
5746    )?;
5747    writeln!(
5748        code,
5749        "    device.new_compute_pipeline_state_with_function(&func)"
5750    )?;
5751    writeln!(
5752        code,
5753        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
5754    )?;
5755    writeln!(code, "}}")?;
5756    writeln!(code)?;
5757
5758    Ok(())
5759}
5760
5761// ---------------------------------------------------------------------------
5762// main.rs generation
5763// ---------------------------------------------------------------------------
5764
5765fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
5766    let _sanitized = sanitize_name(model_name);
5767    let _vocab = config.vocab_size;
5768
5769    let mut code = String::with_capacity(16 * 1024);
5770    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
5771    writeln!(
5772        code,
5773        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
5774    )?;
5775    writeln!(code)?;
5776    writeln!(code, "mod model;")?;
5777    writeln!(code)?;
5778    writeln!(code, "use std::io::Write;")?;
5779    writeln!(code, "use std::time::Instant;")?;
5780    writeln!(code, "use serde::Deserialize;")?;
5781    writeln!(code)?;
5782
5783    // -- main function --
5784    writeln!(code, "fn main() {{")?;
5785    writeln!(
5786        code,
5787        "    let args: Vec<String> = std::env::args().collect();"
5788    )?;
5789    writeln!(code)?;
5790    writeln!(
5791        code,
5792        "    // Detect --serve mode (only requires weights + tokenizer)"
5793    )?;
5794    writeln!(
5795        code,
5796        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
5797    )?;
5798    writeln!(code)?;
5799    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
5800    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
5801    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
5802    writeln!(code, "        std::process::exit(1);")?;
5803    writeln!(code, "    }}")?;
5804    writeln!(code)?;
5805    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
5806    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
5807    writeln!(code, "        std::process::exit(1);")?;
5808    writeln!(code, "    }}")?;
5809    writeln!(code)?;
5810    writeln!(code, "    let weights_path = &args[1];")?;
5811    writeln!(code, "    let tokenizer_path = &args[2];")?;
5812    writeln!(code)?;
5813    writeln!(code, "    // Parse optional flags")?;
5814    writeln!(code, "    let mut max_tokens: usize = 128;")?;
5815    writeln!(code, "    let mut port: u16 = 8080;")?;
5816    writeln!(
5817        code,
5818        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
5819    )?;
5820    writeln!(
5821        code,
5822        "    let profile = args.iter().any(|a| a == \"--profile\");"
5823    )?;
5824    writeln!(code, "    let mut i = 3;")?;
5825    writeln!(code, "    while i < args.len() {{")?;
5826    writeln!(
5827        code,
5828        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
5829    )?;
5830    writeln!(
5831        code,
5832        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
5833    )?;
5834    writeln!(code, "            i += 2;")?;
5835    writeln!(
5836        code,
5837        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
5838    )?;
5839    writeln!(
5840        code,
5841        "            port = args[i + 1].parse().unwrap_or(8080);"
5842    )?;
5843    writeln!(code, "            i += 2;")?;
5844    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
5845    writeln!(code, "            i += 1;")?;
5846    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
5847    writeln!(code, "            i += 1;")?;
5848    writeln!(code, "        }} else {{")?;
5849    writeln!(code, "            i += 1;")?;
5850    writeln!(code, "        }}")?;
5851    writeln!(code, "    }}")?;
5852    writeln!(code)?;
5853
5854    // -- load model (shared by both modes) --
5855    writeln!(
5856        code,
5857        "    // Memory-map weights for zero-copy loading on Apple Silicon"
5858    )?;
5859    writeln!(
5860        code,
5861        "    let weights_file = std::fs::File::open(weights_path)"
5862    )?;
5863    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
5864    writeln!(
5865        code,
5866        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
5867    )?;
5868    writeln!(code)?;
5869    writeln!(code, "    // Load tokenizer")?;
5870    writeln!(
5871        code,
5872        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
5873    )?;
5874    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
5875    writeln!(code)?;
5876    writeln!(code, "    // Create Metal model")?;
5877    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
5878    writeln!(
5879        code,
5880        "    let mut model = model::MetalModel::new(&weights_mmap);"
5881    )?;
5882    writeln!(code)?;
5883
5884    // -- branch: serve vs CLI --
5885    writeln!(code, "    if serve_mode {{")?;
5886    writeln!(code, "        serve(model, tokenizer, port);")?;
5887    writeln!(code, "    }} else {{")?;
5888    writeln!(code, "        let prompt = &args[3];")?;
5889    writeln!(
5890        code,
5891        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
5892    )?;
5893    writeln!(code, "    }}")?;
5894    writeln!(code, "}}")?;
5895    writeln!(code)?;
5896
5897    // -- cli_mode function --
5898    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
5899    writeln!(code, "    // Tokenize prompt")?;
5900    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
5901    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
5902    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
5903    writeln!(code)?;
5904    writeln!(
5905        code,
5906        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
5907    )?;
5908    writeln!(
5909        code,
5910        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
5911    )?;
5912    writeln!(
5913        code,
5914        "    // The last token uses synchronous forward() to get logits."
5915    )?;
5916    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
5917    writeln!(code, "    let prefill_start = Instant::now();")?;
5918    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
5919    writeln!(
5920        code,
5921        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
5922    )?;
5923    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
5924    writeln!(code, "    }} else {{")?;
5925    writeln!(code, "        model.forward(prompt_tokens[0])")?;
5926    writeln!(code, "    }};")?;
5927    writeln!(
5928        code,
5929        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
5930    )?;
5931    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
5932    writeln!(
5933        code,
5934        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
5935    )?;
5936    writeln!(
5937        code,
5938        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
5939    )?;
5940    writeln!(code)?;
5941    writeln!(code, "    // Generate tokens")?;
5942    writeln!(code, "    let mut next_token = argmax(&logits);")?;
5943    writeln!(code, "    let gen_start = Instant::now();")?;
5944    writeln!(code, "    let mut generated_count: usize = 0;")?;
5945    writeln!(code)?;
5946    writeln!(code, "    for _ in 0..max_tokens {{")?;
5947    writeln!(
5948        code,
5949        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
5950    )?;
5951    writeln!(code, "            if !quiet {{")?;
5952    writeln!(code, "                print!(\"{{}}\", text);")?;
5953    writeln!(code, "                std::io::stdout().flush().ok();")?;
5954    writeln!(code, "            }}")?;
5955    writeln!(code, "        }}")?;
5956    writeln!(code, "        generated_count += 1;")?;
5957    writeln!(code)?;
5958    writeln!(
5959        code,
5960        "        // Use profiling forward for first token when --profile is set"
5961    )?;
5962    writeln!(
5963        code,
5964        "        let logits = if profile && generated_count == 1 {{"
5965    )?;
5966    writeln!(code, "            model.forward_profile(next_token)")?;
5967    writeln!(code, "        }} else {{")?;
5968    writeln!(code, "            model.forward(next_token)")?;
5969    writeln!(code, "        }};")?;
5970    writeln!(code, "        next_token = argmax(&logits);")?;
5971    writeln!(code)?;
5972    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
5973    writeln!(code, "        if next_token == 2 {{")?;
5974    writeln!(code, "            break;")?;
5975    writeln!(code, "        }}")?;
5976    writeln!(code)?;
5977    writeln!(
5978        code,
5979        "        // Yield between tokens to reduce sustained GPU thermal load."
5980    )?;
5981    writeln!(
5982        code,
5983        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
5984    )?;
5985    writeln!(
5986        code,
5987        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
5988    )?;
5989    writeln!(
5990        code,
5991        "        // briefly, providing a micro-break that helps sustain peak throughput."
5992    )?;
5993    writeln!(code, "        std::thread::yield_now();")?;
5994    writeln!(code, "    }}")?;
5995    writeln!(code, "    if !quiet {{")?;
5996    writeln!(code, "        println!();")?;
5997    writeln!(code, "    }}")?;
5998    writeln!(
5999        code,
6000        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6001    )?;
6002    writeln!(
6003        code,
6004        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6005    )?;
6006    writeln!(
6007        code,
6008        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6009    )?;
6010    writeln!(code, "}}")?;
6011    writeln!(code)?;
6012
6013    // -- argmax helper --
6014    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6015    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6016    writeln!(code, "    logits.iter()")?;
6017    writeln!(code, "        .enumerate()")?;
6018    writeln!(
6019        code,
6020        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6021    )?;
6022    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6023    writeln!(code, "        .unwrap_or(0)")?;
6024    writeln!(code, "}}")?;
6025    writeln!(code)?;
6026
6027    // -- Request/Response types for OpenAI API --
6028    writeln!(
6029        code,
6030        "// -----------------------------------------------------------------------"
6031    )?;
6032    writeln!(code, "// OpenAI-compatible API server")?;
6033    writeln!(
6034        code,
6035        "// -----------------------------------------------------------------------"
6036    )?;
6037    writeln!(code)?;
6038    writeln!(code, "#[derive(Deserialize)]")?;
6039    writeln!(code, "struct ChatRequest {{")?;
6040    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6041    writeln!(code, "    #[serde(default)]")?;
6042    writeln!(code, "    stream: Option<bool>,")?;
6043    writeln!(code, "    #[serde(default)]")?;
6044    writeln!(code, "    max_tokens: Option<usize>,")?;
6045    writeln!(code, "    #[serde(default)]")?;
6046    writeln!(code, "    temperature: Option<f32>,")?;
6047    writeln!(code, "    #[serde(default)]")?;
6048    writeln!(code, "    model: Option<String>,")?;
6049    writeln!(code, "}}")?;
6050    writeln!(code)?;
6051    writeln!(code, "#[derive(Deserialize)]")?;
6052    writeln!(code, "struct ChatMessage {{")?;
6053    writeln!(code, "    role: String,")?;
6054    writeln!(code, "    content: String,")?;
6055    writeln!(code, "}}")?;
6056    writeln!(code)?;
6057
6058    // -- format_chat_messages --
6059    writeln!(
6060        code,
6061        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6062    )?;
6063    writeln!(code, "    let mut prompt = String::new();")?;
6064    writeln!(code, "    for msg in messages {{")?;
6065    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6066    writeln!(code, "    }}")?;
6067    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6068    writeln!(code, "    prompt")?;
6069    writeln!(code, "}}")?;
6070    writeln!(code)?;
6071
6072    // -- prefill helper --
6073    writeln!(
6074        code,
6075        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6076    )?;
6077    writeln!(code, "    let len = tokens.len();")?;
6078    writeln!(code, "    if len > 1 {{")?;
6079    writeln!(
6080        code,
6081        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6082    )?;
6083    writeln!(code, "    }}")?;
6084    writeln!(code, "    model.forward(tokens[len - 1])")?;
6085    writeln!(code, "}}")?;
6086    writeln!(code)?;
6087
6088    // -- serve function --
6089    writeln!(
6090        code,
6091        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6092    )?;
6093    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6094    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6095    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6096    writeln!(
6097        code,
6098        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6099    )?;
6100    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6101    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6102    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6103    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6104    writeln!(code)?;
6105    writeln!(code, "    for request in server.incoming_requests() {{")?;
6106    writeln!(code, "        let method = request.method().to_string();")?;
6107    writeln!(code, "        let url = request.url().to_string();")?;
6108    writeln!(code)?;
6109    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6110
6111    // -- POST /v1/chat/completions --
6112    writeln!(
6113        code,
6114        "            (\"POST\", \"/v1/chat/completions\") => {{"
6115    )?;
6116    writeln!(
6117        code,
6118        "                handle_chat_completion(&mut model, &tokenizer, request);"
6119    )?;
6120    writeln!(code, "            }}")?;
6121
6122    // -- GET /v1/models --
6123    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6124    writeln!(code, "                let body = serde_json::json!({{")?;
6125    writeln!(code, "                    \"object\": \"list\",")?;
6126    writeln!(code, "                    \"data\": [{{")?;
6127    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6128    writeln!(code, "                        \"object\": \"model\",")?;
6129    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6130    writeln!(code, "                    }}]")?;
6131    writeln!(code, "                }});")?;
6132    writeln!(
6133        code,
6134        "                let resp = tiny_http::Response::from_string(body.to_string())"
6135    )?;
6136    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6137    writeln!(code, "                request.respond(resp).ok();")?;
6138    writeln!(code, "            }}")?;
6139
6140    // -- GET /health --
6141    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6142    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6143    writeln!(code, "                request.respond(resp).ok();")?;
6144    writeln!(code, "            }}")?;
6145
6146    // -- 404 --
6147    writeln!(code, "            _ => {{")?;
6148    writeln!(
6149        code,
6150        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6151    )?;
6152    writeln!(code, "                    .with_status_code(404);")?;
6153    writeln!(code, "                request.respond(resp).ok();")?;
6154    writeln!(code, "            }}")?;
6155    writeln!(code, "        }}")?;
6156    writeln!(code, "    }}")?;
6157    writeln!(code, "}}")?;
6158    writeln!(code)?;
6159
6160    // -- handle_chat_completion --
6161    writeln!(code, "fn handle_chat_completion(")?;
6162    writeln!(code, "    model: &mut model::MetalModel,")?;
6163    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6164    writeln!(code, "    mut request: tiny_http::Request,")?;
6165    writeln!(code, ") {{")?;
6166    writeln!(code, "    // Read request body")?;
6167    writeln!(code, "    let mut body = String::new();")?;
6168    writeln!(
6169        code,
6170        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6171    )?;
6172    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6173    writeln!(code, "            .with_status_code(400);")?;
6174    writeln!(code, "        request.respond(resp).ok();")?;
6175    writeln!(code, "        return;")?;
6176    writeln!(code, "    }}")?;
6177    writeln!(code)?;
6178    writeln!(code, "    // Parse JSON")?;
6179    writeln!(
6180        code,
6181        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6182    )?;
6183    writeln!(code, "        Ok(r) => r,")?;
6184    writeln!(code, "        Err(e) => {{")?;
6185    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6186    writeln!(code, "                .with_status_code(400);")?;
6187    writeln!(code, "            request.respond(resp).ok();")?;
6188    writeln!(code, "            return;")?;
6189    writeln!(code, "        }}")?;
6190    writeln!(code, "    }};")?;
6191    writeln!(code)?;
6192    writeln!(
6193        code,
6194        "    let prompt = format_chat_messages(&req.messages);"
6195    )?;
6196    writeln!(
6197        code,
6198        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6199    )?;
6200    writeln!(code, "        Ok(e) => e,")?;
6201    writeln!(code, "        Err(e) => {{")?;
6202    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6203    writeln!(code, "                .with_status_code(500);")?;
6204    writeln!(code, "            request.respond(resp).ok();")?;
6205    writeln!(code, "            return;")?;
6206    writeln!(code, "        }}")?;
6207    writeln!(code, "    }};")?;
6208    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6209    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6210    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6211    writeln!(
6212        code,
6213        "    let _temperature = req.temperature.unwrap_or(1.0);"
6214    )?;
6215    writeln!(code)?;
6216
6217    // -- Reset KV cache for each request --
6218    writeln!(code, "    model.reset();")?;
6219    writeln!(code)?;
6220
6221    // -- Prefill with timing --
6222    writeln!(code, "    let prefill_start = Instant::now();")?;
6223    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6224    writeln!(
6225        code,
6226        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6227    )?;
6228    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6229    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6230    writeln!(code)?;
6231
6232    writeln!(code, "    if stream {{")?;
6233
6234    // -- SSE streaming response --
6235    writeln!(
6236        code,
6237        "        // SSE streaming: generate tokens and build SSE body"
6238    )?;
6239    writeln!(code, "        let gen_start = Instant::now();")?;
6240    writeln!(code, "        let mut generated_count: usize = 0;")?;
6241    writeln!(code, "        let mut sse_body = String::new();")?;
6242    writeln!(code, "        for _ in 0..max_tokens {{")?;
6243    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6244    writeln!(
6245        code,
6246        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6247    )?;
6248    writeln!(
6249        code,
6250        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6251    )?;
6252    writeln!(
6253        code,
6254        "                // escaped includes surrounding quotes, strip them"
6255    )?;
6256    writeln!(
6257        code,
6258        "                let inner = &escaped[1..escaped.len()-1];"
6259    )?;
6260    writeln!(code, "                sse_body.push_str(&format!(")?;
6261    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6262    writeln!(code, "                    inner")?;
6263    writeln!(code, "                ));")?;
6264    writeln!(code, "            }}")?;
6265    writeln!(code, "            generated_count += 1;")?;
6266    writeln!(code, "            let logits = model.forward(next_token);")?;
6267    writeln!(code, "            next_token = argmax(&logits);")?;
6268    writeln!(code, "        }}")?;
6269    writeln!(
6270        code,
6271        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6272    )?;
6273    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6274    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6275    writeln!(code)?;
6276    writeln!(
6277        code,
6278        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6279    )?;
6280    writeln!(code, "        sse_body.push_str(&format!(")?;
6281    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6282    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6283    writeln!(code, "        ));")?;
6284    writeln!(code)?;
6285    writeln!(
6286        code,
6287        "        let resp = tiny_http::Response::from_string(sse_body)"
6288    )?;
6289    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6290    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6291    writeln!(code, "        request.respond(resp).ok();")?;
6292
6293    writeln!(code, "    }} else {{")?;
6294
6295    // -- Non-streaming response --
6296    writeln!(
6297        code,
6298        "        // Non-streaming: generate all tokens, return JSON"
6299    )?;
6300    writeln!(code, "        let gen_start = Instant::now();")?;
6301    writeln!(code, "        let mut generated_count: usize = 0;")?;
6302    writeln!(code, "        let mut generated = String::new();")?;
6303    writeln!(code, "        for _ in 0..max_tokens {{")?;
6304    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6305    writeln!(
6306        code,
6307        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6308    )?;
6309    writeln!(code, "                generated.push_str(&text);")?;
6310    writeln!(code, "            }}")?;
6311    writeln!(code, "            generated_count += 1;")?;
6312    writeln!(code, "            let logits = model.forward(next_token);")?;
6313    writeln!(code, "            next_token = argmax(&logits);")?;
6314    writeln!(code, "        }}")?;
6315    writeln!(
6316        code,
6317        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6318    )?;
6319    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6320    writeln!(code)?;
6321    writeln!(code, "        let resp_json = serde_json::json!({{")?;
6322    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
6323    writeln!(code, "            \"object\": \"chat.completion\",")?;
6324    writeln!(code, "            \"choices\": [{{")?;
6325    writeln!(code, "                \"index\": 0,")?;
6326    writeln!(code, "                \"message\": {{")?;
6327    writeln!(code, "                    \"role\": \"assistant\",")?;
6328    writeln!(code, "                    \"content\": generated")?;
6329    writeln!(code, "                }},")?;
6330    writeln!(code, "                \"finish_reason\": \"stop\"")?;
6331    writeln!(code, "            }}],")?;
6332    writeln!(code, "            \"usage\": {{")?;
6333    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
6334    writeln!(
6335        code,
6336        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6337    )?;
6338    writeln!(
6339        code,
6340        "                \"generation_tokens\": generated_count,"
6341    )?;
6342    writeln!(
6343        code,
6344        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6345    )?;
6346    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
6347    writeln!(code, "            }}")?;
6348    writeln!(code, "        }});")?;
6349    writeln!(
6350        code,
6351        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
6352    )?;
6353    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6354    writeln!(code, "        request.respond(resp).ok();")?;
6355    writeln!(code, "    }}")?;
6356    writeln!(code, "}}")?;
6357
6358    Ok(code)
6359}
6360
6361// ---------------------------------------------------------------------------
6362// Tests
6363// ---------------------------------------------------------------------------
6364
6365#[cfg(test)]
6366mod tests {
6367    use super::*;
6368    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
6369
6370    fn minimal_config() -> ModelConfig {
6371        ModelConfig {
6372            architecture: Architecture::Llama,
6373            hidden_size: 64,
6374            intermediate_size: 128,
6375            num_layers: 2,
6376            num_attention_heads: 4,
6377            num_kv_heads: 4,
6378            head_dim: 16,
6379            vocab_size: 256,
6380            max_seq_len: 512,
6381            rms_norm_eps: 1e-5,
6382            rope_theta: 10000.0,
6383            dtype: DType::F32,
6384            sliding_window_size: None,
6385            qkv_bias: false,
6386        }
6387    }
6388
6389    fn minimal_graph() -> Graph {
6390        Graph::new("test-metal").with_config(minimal_config())
6391    }
6392
6393    #[test]
6394    fn generate_metal_project_creates_files() {
6395        let dir = tempfile::tempdir().unwrap();
6396        let graph = minimal_graph();
6397        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
6398
6399        assert!(
6400            dir.path().join("Cargo.toml").exists(),
6401            "Cargo.toml should be created"
6402        );
6403        assert!(
6404            dir.path().join("src/model.rs").exists(),
6405            "src/model.rs should be created"
6406        );
6407        assert!(
6408            dir.path().join("src/main.rs").exists(),
6409            "src/main.rs should be created"
6410        );
6411        assert!(
6412            dir.path().join("shaders/kernels.metal").exists(),
6413            "shaders/kernels.metal should be created"
6414        );
6415    }
6416
6417    #[test]
6418    fn generated_cargo_toml_has_metal_dep() {
6419        let toml = generate_cargo_toml("my-model");
6420        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
6421        assert!(
6422            toml.contains("tokenizers"),
6423            "Cargo.toml should depend on tokenizers"
6424        );
6425        assert!(
6426            toml.contains("memmap2"),
6427            "Cargo.toml should depend on memmap2"
6428        );
6429        assert!(toml.contains("half"), "Cargo.toml should depend on half");
6430    }
6431
6432    #[test]
6433    fn generated_model_rs_contains_metal_code() {
6434        let config = minimal_config();
6435        let model_rs = generate_model_rs(&config).unwrap();
6436
6437        assert!(
6438            model_rs.contains("pub struct MetalModel"),
6439            "model.rs should define MetalModel struct"
6440        );
6441        assert!(
6442            model_rs.contains("matmul_pipeline: ComputePipelineState"),
6443            "MetalModel should have matmul_pipeline field"
6444        );
6445        assert!(
6446            model_rs.contains("Device::system_default()"),
6447            "model.rs should use Metal device"
6448        );
6449        assert!(
6450            model_rs.contains("new_library_with_source"),
6451            "model.rs should compile Metal shaders"
6452        );
6453        assert!(
6454            model_rs.contains("fn new(weights: &[u8])"),
6455            "MetalModel should implement new()"
6456        );
6457        assert!(
6458            model_rs.contains("fn forward(&mut self, token_id: u32)"),
6459            "MetalModel should implement forward()"
6460        );
6461    }
6462
6463    #[test]
6464    fn generated_shaders_contain_kernel_names() {
6465        let shaders = generate_metal_shaders(&minimal_config());
6466
6467        assert!(
6468            shaders.contains("kernel void matmul_vec"),
6469            "shaders should contain matmul_vec kernel"
6470        );
6471        assert!(
6472            shaders.contains("kernel void rms_norm"),
6473            "shaders should contain rms_norm kernel"
6474        );
6475        assert!(
6476            shaders.contains("kernel void rope"),
6477            "shaders should contain rope kernel"
6478        );
6479        assert!(
6480            shaders.contains("kernel void softmax"),
6481            "shaders should contain softmax kernel"
6482        );
6483        assert!(
6484            shaders.contains("kernel void silu_mul("),
6485            "shaders should contain silu_mul kernel"
6486        );
6487        assert!(
6488            shaders.contains("kernel void silu_mul_fused"),
6489            "shaders should contain silu_mul_fused kernel"
6490        );
6491        assert!(
6492            shaders.contains("kernel void elementwise_add"),
6493            "shaders should contain elementwise_add kernel"
6494        );
6495        assert!(
6496            shaders.contains("kernel void attention"),
6497            "shaders should contain attention kernel"
6498        );
6499        assert!(
6500            shaders.contains("kernel void add_inplace"),
6501            "shaders should contain add_inplace kernel"
6502        );
6503        assert!(
6504            shaders.contains("kernel void copy_buffer"),
6505            "shaders should contain copy_buffer kernel"
6506        );
6507        assert!(
6508            shaders.contains("kernel void copy_offset"),
6509            "shaders should contain copy_offset kernel"
6510        );
6511    }
6512
6513    #[test]
6514    fn generated_shaders_use_simdgroup_features() {
6515        let shaders = generate_metal_shaders(&minimal_config());
6516
6517        assert!(
6518            shaders.contains("threadgroup_barrier"),
6519            "shaders should use threadgroup barriers"
6520        );
6521        assert!(
6522            shaders.contains("threadgroup float"),
6523            "shaders should use threadgroup shared memory"
6524        );
6525        assert!(
6526            shaders.contains("thread_index_in_threadgroup"),
6527            "shaders should use threadgroup indexing"
6528        );
6529        assert!(
6530            shaders.contains("simd_sum"),
6531            "shaders should use simd_sum for warp-level reduction"
6532        );
6533        assert!(
6534            shaders.contains("simd_max"),
6535            "attention kernel should use simd_max for cooperative softmax"
6536        );
6537        assert!(
6538            shaders.contains("thread_index_in_simdgroup"),
6539            "shaders should use simdgroup lane indexing"
6540        );
6541        assert!(
6542            shaders.contains("simdgroup_index_in_threadgroup"),
6543            "shaders should use simdgroup indexing within threadgroup"
6544        );
6545        assert!(
6546            shaders.contains("float4"),
6547            "matmul_vec should use float4 vectorized loads"
6548        );
6549    }
6550
6551    #[test]
6552    fn generated_main_rs_has_tokenizer_usage() {
6553        let config = minimal_config();
6554        let main_rs = generate_main_rs("test-model", &config).unwrap();
6555
6556        assert!(
6557            main_rs.contains("tokenizers::Tokenizer"),
6558            "main.rs should use tokenizers crate"
6559        );
6560        assert!(
6561            main_rs.contains("MetalModel::new"),
6562            "main.rs should call MetalModel::new"
6563        );
6564        assert!(
6565            main_rs.contains("model.forward"),
6566            "main.rs should call model.forward"
6567        );
6568        assert!(
6569            main_rs.contains("memmap2"),
6570            "main.rs should use memmap2 for zero-copy weight loading"
6571        );
6572    }
6573
6574    #[test]
6575    fn missing_config_returns_error() {
6576        let dir = tempfile::tempdir().unwrap();
6577        let graph = Graph::new("no-config");
6578        let result = generate_metal_project(&graph, dir.path(), "fail");
6579        assert!(
6580            matches!(result, Err(MetalCodegenError::MissingConfig)),
6581            "should fail with MissingConfig when graph has no config"
6582        );
6583    }
6584
6585    #[test]
6586    fn sanitize_name_works() {
6587        assert_eq!(sanitize_name("My Model!"), "my-model");
6588        assert_eq!(sanitize_name("test_model"), "test-model");
6589        assert_eq!(sanitize_name("simple"), "simple");
6590    }
6591
6592    #[test]
6593    fn generated_forward_uses_single_command_buffer() {
6594        let config = minimal_config();
6595        let model_rs = generate_model_rs(&config).unwrap();
6596
6597        // The forward function should create exactly one command buffer.
6598        // Use the exact signature to avoid matching forward_prefill/forward_profile.
6599        let forward_start = model_rs
6600            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6601            .unwrap();
6602        let forward_body = &model_rs[forward_start..];
6603        // End at the next pub/private method
6604        let forward_end = forward_body
6605            .find("\n    pub fn forward_profile")
6606            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6607            .or_else(|| forward_body.find("\n    fn dispatch_"))
6608            .unwrap_or(forward_body.len());
6609        let forward_code = &forward_body[..forward_end];
6610
6611        // Should have exactly one new_command_buffer call
6612        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
6613        assert_eq!(
6614            cmd_buf_count, 1,
6615            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
6616        );
6617
6618        // Should have exactly one commit call
6619        let commit_count = forward_code.matches("cmd.commit()").count();
6620        assert_eq!(
6621            commit_count, 1,
6622            "forward() should commit exactly once, found {commit_count}"
6623        );
6624
6625        // Should wait: once for cmd + possibly once for prev_cmd drain
6626        let wait_count = forward_code.matches("wait_until_completed()").count();
6627        assert!(
6628            wait_count >= 1 && wait_count <= 2,
6629            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
6630        );
6631    }
6632
6633    #[test]
6634    fn generated_model_has_preallocated_working_buffers() {
6635        let config = minimal_config();
6636        let model_rs = generate_model_rs(&config).unwrap();
6637
6638        for buf_name in &[
6639            "normed_buf",
6640            "qkv_buf",
6641            "attn_out_buf",
6642            "attn_proj_buf",
6643            "gate_up_buf",
6644            "ffn_hidden_buf",
6645            "ffn_out_buf",
6646            "add_tmp_buf",
6647        ] {
6648            assert!(
6649                model_rs.contains(&format!("{buf_name}: Buffer")),
6650                "MetalModel should have pre-allocated {buf_name} field"
6651            );
6652        }
6653    }
6654
6655    #[test]
6656    fn generated_dispatch_helpers_take_compute_encoder_ref() {
6657        let config = minimal_config();
6658        let model_rs = generate_model_rs(&config).unwrap();
6659
6660        for method in &[
6661            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
6662            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
6663            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
6664            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
6665            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
6666            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
6667            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
6668            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
6669            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
6670            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
6671            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
6672            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
6673            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
6674        ] {
6675            assert!(
6676                model_rs.contains(method),
6677                "model.rs should contain dispatch helper: {method}"
6678            );
6679        }
6680    }
6681
6682    #[test]
6683    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
6684        let config = minimal_config();
6685        let model_rs = generate_model_rs(&config).unwrap();
6686
6687        // Find dispatch helpers section and check none create their own encoders
6688        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
6689        let helpers_code = &model_rs[helpers_start..];
6690
6691        // None of the dispatch_ helpers should call new_command_buffer
6692        assert!(
6693            !helpers_code.contains("self.queue.new_command_buffer()"),
6694            "dispatch helpers should not create their own command buffers"
6695        );
6696
6697        // None should create their own compute encoders
6698        assert!(
6699            !helpers_code.contains("new_compute_command_encoder()"),
6700            "dispatch helpers should not create their own compute encoders"
6701        );
6702
6703        // None should call end_encoding
6704        assert!(
6705            !helpers_code.contains("end_encoding()"),
6706            "dispatch helpers should not call end_encoding"
6707        );
6708
6709        // None should call commit or wait
6710        assert!(
6711            !helpers_code.contains(".commit()"),
6712            "dispatch helpers should not commit command buffers"
6713        );
6714        assert!(
6715            !helpers_code.contains("wait_until_completed"),
6716            "dispatch helpers should not wait on command buffers"
6717        );
6718    }
6719
6720    #[test]
6721    fn generated_forward_batches_compute_encoders() {
6722        let config = minimal_config();
6723        let model_rs = generate_model_rs(&config).unwrap();
6724
6725        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
6726        let forward_start = model_rs
6727            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6728            .unwrap();
6729        let forward_body = &model_rs[forward_start..];
6730        let forward_end = forward_body
6731            .find("\n    pub fn forward_profile")
6732            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6733            .or_else(|| forward_body.find("\n    fn dispatch_"))
6734            .unwrap_or(forward_body.len());
6735        let forward_code = &forward_body[..forward_end];
6736
6737        // Forward should not allocate new buffers
6738        assert!(
6739            !forward_code.contains("device.new_buffer"),
6740            "forward() should not allocate new buffers per call"
6741        );
6742
6743        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
6744        // Copy operations use compute copy kernels instead of blit encoders.
6745        let compute_encoder_count = forward_code
6746            .matches("new_compute_command_encoder()")
6747            .count();
6748        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
6749
6750        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
6751        assert_eq!(
6752            compute_encoder_count, 1,
6753            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
6754        );
6755        assert_eq!(
6756            blit_encoder_count, 0,
6757            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
6758        );
6759    }
6760
6761    #[test]
6762    fn generated_forward_uses_add_inplace() {
6763        let config = minimal_config();
6764        let model_rs = generate_model_rs(&config).unwrap();
6765
6766        // Should use in-place add (no blit copy-back needed)
6767        assert!(
6768            model_rs.contains("dispatch_add_inplace"),
6769            "forward() should use dispatch_add_inplace for residual connections"
6770        );
6771        assert!(
6772            model_rs.contains("add_inplace_pipeline"),
6773            "MetalModel should have add_inplace_pipeline"
6774        );
6775    }
6776
6777    fn minimal_q8_config() -> ModelConfig {
6778        ModelConfig {
6779            architecture: Architecture::Llama,
6780            hidden_size: 64,
6781            intermediate_size: 128,
6782            num_layers: 2,
6783            num_attention_heads: 4,
6784            num_kv_heads: 4,
6785            head_dim: 16,
6786            vocab_size: 256,
6787            max_seq_len: 512,
6788            rms_norm_eps: 1e-5,
6789            rope_theta: 10000.0,
6790            dtype: DType::Q8_0,
6791            sliding_window_size: None,
6792            qkv_bias: false,
6793        }
6794    }
6795
6796    #[test]
6797    fn generated_shaders_contain_q8_kernel() {
6798        let shaders = generate_metal_shaders(&minimal_config());
6799
6800        assert!(
6801            shaders.contains("kernel void matmul_vec_q8"),
6802            "shaders should contain matmul_vec_q8 kernel"
6803        );
6804        assert!(
6805            shaders.contains("device const uchar* matrix"),
6806            "matmul_vec_q8 should accept raw Q8_0 bytes"
6807        );
6808        assert!(
6809            shaders.contains("packed_short4"),
6810            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
6811        );
6812        assert!(
6813            shaders.contains("as_type<char2>"),
6814            "matmul_vec_q8 should bitcast short lanes to char2"
6815        );
6816        assert!(
6817            shaders.contains("device const half*"),
6818            "matmul_vec_q8 should read f16 scale via half pointer"
6819        );
6820    }
6821
6822    #[test]
6823    fn generated_model_uses_fused_qkv_projections() {
6824        let config = minimal_config();
6825        let model_rs = generate_model_rs(&config).unwrap();
6826
6827        // Should have fused QKV weight in layer buffers
6828        assert!(
6829            model_rs.contains("qkv_weight: Buffer"),
6830            "LayerBuffers should have fused qkv_weight field"
6831        );
6832        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
6833        assert!(
6834            !model_rs.contains("    q_weight: Buffer"),
6835            "LayerBuffers should not have separate q_weight field"
6836        );
6837        assert!(
6838            !model_rs.contains("    k_weight: Buffer"),
6839            "LayerBuffers should not have separate k_weight field"
6840        );
6841        assert!(
6842            !model_rs.contains("    v_weight: Buffer"),
6843            "LayerBuffers should not have separate v_weight field"
6844        );
6845
6846        // Should have fused gate_up_weight
6847        assert!(
6848            model_rs.contains("gate_up_weight: Buffer"),
6849            "LayerBuffers should have fused gate_up_weight field"
6850        );
6851        // Should NOT have separate gate/up weight fields
6852        assert!(
6853            !model_rs.contains("    gate_weight: Buffer"),
6854            "LayerBuffers should not have separate gate_weight field"
6855        );
6856        assert!(
6857            !model_rs.contains("    up_weight: Buffer"),
6858            "LayerBuffers should not have separate up_weight field"
6859        );
6860
6861        // Should have fused working buffers
6862        assert!(
6863            model_rs.contains("qkv_buf: Buffer"),
6864            "MetalModel should have fused qkv_buf"
6865        );
6866        assert!(
6867            model_rs.contains("gate_up_buf: Buffer"),
6868            "MetalModel should have fused gate_up_buf"
6869        );
6870
6871        // Forward pass should use fused dispatch
6872        assert!(
6873            model_rs.contains("dispatch_silu_mul_fused"),
6874            "forward pass should use dispatch_silu_mul_fused"
6875        );
6876        assert!(
6877            model_rs.contains("dispatch_rope_offset"),
6878            "forward pass should use dispatch_rope_offset for fused QKV"
6879        );
6880        assert!(
6881            model_rs.contains("dispatch_attention_offset"),
6882            "forward pass should use dispatch_attention_offset for fused QKV"
6883        );
6884    }
6885
6886    #[test]
6887    fn q8_model_has_matmul_q8_pipeline() {
6888        let config = minimal_q8_config();
6889        let model_rs = generate_model_rs(&config).unwrap();
6890
6891        assert!(
6892            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
6893            "MetalModel should have matmul_q8_pipeline field"
6894        );
6895        assert!(
6896            model_rs.contains("matmul_q8_pipeline,"),
6897            "MetalModel Self should include matmul_q8_pipeline"
6898        );
6899    }
6900
6901    #[test]
6902    fn q8_model_uses_dispatch_matmul_q8() {
6903        let config = minimal_q8_config();
6904        let model_rs = generate_model_rs(&config).unwrap();
6905
6906        assert!(
6907            model_rs.contains("dispatch_matmul_q8"),
6908            "Q8_0 model should use dispatch_matmul_q8 for projections"
6909        );
6910        assert!(
6911            model_rs.contains("fn dispatch_matmul_q8"),
6912            "model.rs should define dispatch_matmul_q8 method"
6913        );
6914    }
6915
6916    #[test]
6917    fn q8_model_loads_raw_bytes_not_dequantized() {
6918        let config = minimal_q8_config();
6919        let model_rs = generate_model_rs(&config).unwrap();
6920
6921        // Should NOT contain dequantization code
6922        assert!(
6923            !model_rs.contains("f16_to_f32"),
6924            "Q8_0 model should not dequantize weights to f32"
6925        );
6926        assert!(
6927            !model_rs.contains("f32_data"),
6928            "Q8_0 model should not create f32 weight data"
6929        );
6930
6931        // Should load raw Q8_0 bytes directly
6932        assert!(
6933            model_rs.contains("total_raw as u64"),
6934            "Q8_0 model should load raw bytes into Metal buffer"
6935        );
6936    }
6937
6938    #[test]
6939    fn q8_model_norms_stay_f32() {
6940        let config = minimal_q8_config();
6941        let model_rs = generate_model_rs(&config).unwrap();
6942
6943        // Norm weights should still use f32 buffers
6944        assert!(
6945            model_rs.contains("let attn_norm = next_f32_buffer"),
6946            "attn_norm should use f32 buffer even for Q8_0 models"
6947        );
6948        assert!(
6949            model_rs.contains("let ffn_norm = next_f32_buffer"),
6950            "ffn_norm should use f32 buffer even for Q8_0 models"
6951        );
6952        assert!(
6953            model_rs.contains("let norm_buf = next_f32_buffer"),
6954            "final norm should use f32 buffer even for Q8_0 models"
6955        );
6956    }
6957
6958    #[test]
6959    fn q8_model_uses_fused_weight_loading() {
6960        let config = minimal_q8_config();
6961        let model_rs = generate_model_rs(&config).unwrap();
6962
6963        // Should use fused Q8 buffer loading for QKV
6964        assert!(
6965            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
6966            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
6967        );
6968        // Should use fused Q8 buffer loading for gate+up
6969        assert!(
6970            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
6971            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
6972        );
6973        // Should still use regular q8 buffer for individual weights
6974        assert!(
6975            model_rs.contains("let o_weight = next_q8_buffer"),
6976            "Q8_0 model should use next_q8_buffer for O weight"
6977        );
6978        assert!(
6979            model_rs.contains("let down_weight = next_q8_buffer"),
6980            "Q8_0 model should use next_q8_buffer for down weight"
6981        );
6982    }
6983
6984    #[test]
6985    fn f32_model_does_not_use_q8_dispatch() {
6986        let config = minimal_config();
6987        let model_rs = generate_model_rs(&config).unwrap();
6988
6989        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
6990        let forward_start = model_rs
6991            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6992            .unwrap();
6993        let forward_body = &model_rs[forward_start..];
6994        let forward_end = forward_body
6995            .find("\n    fn dispatch_")
6996            .unwrap_or(forward_body.len());
6997        let forward_code = &forward_body[..forward_end];
6998
6999        assert!(
7000            !forward_code.contains("dispatch_matmul_q8"),
7001            "f32 model forward should not use dispatch_matmul_q8"
7002        );
7003    }
7004
7005    #[test]
7006    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7007        let config = minimal_q8_config();
7008        let model_rs = generate_model_rs(&config).unwrap();
7009
7010        assert!(
7011            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7012            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7013        );
7014    }
7015
7016    #[test]
7017    fn generated_model_has_double_buffered_prefill() {
7018        let config = minimal_config();
7019        let model_rs = generate_model_rs(&config).unwrap();
7020
7021        // MetalModel should have prev_cmd field for double-buffered prefill
7022        assert!(
7023            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7024            "MetalModel should have prev_cmd field for double-buffered prefill"
7025        );
7026
7027        // Should have forward_prefill method
7028        assert!(
7029            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7030            "MetalModel should have forward_prefill method"
7031        );
7032
7033        // forward() should drain prev_cmd at the start
7034        assert!(
7035            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7036            "forward() should drain prev_cmd from previous prefill"
7037        );
7038    }
7039
7040    #[test]
7041    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7042        let config = minimal_config();
7043        let main_rs = generate_main_rs("test-model", &config).unwrap();
7044
7045        assert!(
7046            main_rs.contains("forward_prefill"),
7047            "main.rs should use forward_prefill for intermediate prompt tokens"
7048        );
7049        assert!(
7050            main_rs.contains("double-buffered"),
7051            "main.rs should document double-buffered prefill"
7052        );
7053    }
7054
7055    #[test]
7056    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7057        let shaders = generate_metal_shaders(&minimal_config());
7058
7059        assert!(
7060            shaders.contains("packed_short4"),
7061            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7062        );
7063        assert!(
7064            shaders.contains("d0[0]"),
7065            "matmul_vec_q8 should index the wide pointer for row 0"
7066        );
7067        assert!(
7068            shaders.contains("as_type<char2>"),
7069            "matmul_vec_q8 should bitcast short lanes to char2"
7070        );
7071        assert!(
7072            shaders.contains("dot("),
7073            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7074        );
7075    }
7076
7077    // ── Q4_0 tests ──────────────────────────────────────────────────────
7078
7079    fn minimal_q4_config() -> ModelConfig {
7080        ModelConfig {
7081            architecture: Architecture::Llama,
7082            hidden_size: 64,
7083            intermediate_size: 128,
7084            num_layers: 2,
7085            num_attention_heads: 4,
7086            num_kv_heads: 4,
7087            head_dim: 16,
7088            vocab_size: 256,
7089            max_seq_len: 512,
7090            rms_norm_eps: 1e-5,
7091            rope_theta: 10000.0,
7092            dtype: DType::Q4_0,
7093            sliding_window_size: None,
7094            qkv_bias: false,
7095        }
7096    }
7097
7098    #[test]
7099    fn generated_shaders_contain_q4_kernel() {
7100        let shaders = generate_metal_shaders(&minimal_config());
7101
7102        assert!(
7103            shaders.contains("kernel void matmul_vec_q4"),
7104            "shaders should contain matmul_vec_q4 kernel"
7105        );
7106        assert!(
7107            shaders.contains("Q4_ROWS_PER_TG"),
7108            "shaders should define Q4_ROWS_PER_TG constant"
7109        );
7110        assert!(
7111            shaders.contains("Q4_ROWS_PER_SG"),
7112            "shaders should define Q4_ROWS_PER_SG constant"
7113        );
7114    }
7115
7116    #[test]
7117    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7118        let shaders = generate_metal_shaders(&minimal_config());
7119
7120        // Q4_0 kernel should use uchar4 for packed byte loads
7121        assert!(
7122            shaders.contains("uchar4"),
7123            "matmul_vec_q4 should use uchar4 for packed byte loads"
7124        );
7125        // Should unpack nibbles with &0xF and >>4
7126        assert!(
7127            shaders.contains("&0xF"),
7128            "matmul_vec_q4 should extract low nibble with &0xF"
7129        );
7130        assert!(
7131            shaders.contains(">>4"),
7132            "matmul_vec_q4 should extract high nibble with >>4"
7133        );
7134        // Should subtract 8 to convert unsigned to signed
7135        assert!(
7136            shaders.contains("-8)"),
7137            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7138        );
7139        // Should use 18-byte block size
7140        assert!(
7141            shaders.contains("blk * 18"),
7142            "matmul_vec_q4 should use 18-byte block stride"
7143        );
7144    }
7145
7146    #[test]
7147    fn q4_model_has_matmul_q4_pipeline() {
7148        let config = minimal_q4_config();
7149        let model_rs = generate_model_rs(&config).unwrap();
7150
7151        assert!(
7152            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7153            "MetalModel should have matmul_q4_pipeline field"
7154        );
7155        assert!(
7156            model_rs.contains("matmul_q4_pipeline,"),
7157            "MetalModel Self should include matmul_q4_pipeline"
7158        );
7159    }
7160
7161    #[test]
7162    fn q4_model_uses_dispatch_matmul_q4() {
7163        let config = minimal_q4_config();
7164        let model_rs = generate_model_rs(&config).unwrap();
7165
7166        assert!(
7167            model_rs.contains("dispatch_matmul_q4"),
7168            "Q4_0 model should use dispatch_matmul_q4 for projections"
7169        );
7170        assert!(
7171            model_rs.contains("fn dispatch_matmul_q4"),
7172            "model.rs should define dispatch_matmul_q4 method"
7173        );
7174    }
7175
7176    #[test]
7177    fn q4_model_loads_raw_bytes_not_dequantized() {
7178        let config = minimal_q4_config();
7179        let model_rs = generate_model_rs(&config).unwrap();
7180
7181        // Should NOT contain dequantization code
7182        assert!(
7183            !model_rs.contains("f16_to_f32"),
7184            "Q4_0 model should not dequantize weights to f32"
7185        );
7186
7187        // Should load raw Q4_0 bytes directly
7188        assert!(
7189            model_rs.contains("total_raw as u64"),
7190            "Q4_0 model should load raw bytes into Metal buffer"
7191        );
7192    }
7193
7194    #[test]
7195    fn q4_model_norms_stay_f32() {
7196        let config = minimal_q4_config();
7197        let model_rs = generate_model_rs(&config).unwrap();
7198
7199        assert!(
7200            model_rs.contains("let attn_norm = next_f32_buffer"),
7201            "attn_norm should use f32 buffer even for Q4_0 models"
7202        );
7203        assert!(
7204            model_rs.contains("let ffn_norm = next_f32_buffer"),
7205            "ffn_norm should use f32 buffer even for Q4_0 models"
7206        );
7207        assert!(
7208            model_rs.contains("let norm_buf = next_f32_buffer"),
7209            "final norm should use f32 buffer even for Q4_0 models"
7210        );
7211    }
7212
7213    #[test]
7214    fn q4_model_uses_fused_weight_loading() {
7215        let config = minimal_q4_config();
7216        let model_rs = generate_model_rs(&config).unwrap();
7217
7218        assert!(
7219            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7220            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7221        );
7222        assert!(
7223            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7224            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7225        );
7226        assert!(
7227            model_rs.contains("let o_weight = next_q4_buffer"),
7228            "Q4_0 model should use next_q4_buffer for O weight"
7229        );
7230        assert!(
7231            model_rs.contains("let down_weight = next_q4_buffer"),
7232            "Q4_0 model should use next_q4_buffer for down weight"
7233        );
7234    }
7235
7236    #[test]
7237    fn q4_dispatch_helper_takes_compute_encoder_ref() {
7238        let config = minimal_q4_config();
7239        let model_rs = generate_model_rs(&config).unwrap();
7240
7241        assert!(
7242            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
7243            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
7244        );
7245    }
7246
7247    #[test]
7248    fn f32_model_does_not_use_q4_dispatch() {
7249        let config = minimal_config();
7250        let model_rs = generate_model_rs(&config).unwrap();
7251
7252        let forward_start = model_rs
7253            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7254            .unwrap();
7255        let forward_body = &model_rs[forward_start..];
7256        let forward_end = forward_body
7257            .find("\n    fn dispatch_")
7258            .unwrap_or(forward_body.len());
7259        let forward_code = &forward_body[..forward_end];
7260
7261        assert!(
7262            !forward_code.contains("dispatch_matmul_q4"),
7263            "f32 model forward should not use dispatch_matmul_q4"
7264        );
7265    }
7266
7267    #[test]
7268    fn q4_model_lm_head_uses_q4_buffer() {
7269        let config = minimal_q4_config();
7270        let model_rs = generate_model_rs(&config).unwrap();
7271
7272        assert!(
7273            model_rs.contains("let lm_head_buf = next_q4_buffer"),
7274            "Q4_0 model should use next_q4_buffer for lm_head"
7275        );
7276    }
7277
7278    #[test]
7279    fn vec_tile_size_matches_model_dimensions() {
7280        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
7281        let small = minimal_config();
7282        let shaders_small = generate_metal_shaders(&small);
7283        assert!(
7284            shaders_small.contains("vec_tile[128]"),
7285            "vec_tile should be sized to max(hidden, intermediate) = 128"
7286        );
7287
7288        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
7289        let mut large = minimal_config();
7290        large.hidden_size = 2048;
7291        large.intermediate_size = 8192;
7292        let shaders_large = generate_metal_shaders(&large);
7293        assert!(
7294            shaders_large.contains("vec_tile[8192]"),
7295            "vec_tile should be 8192 for models with intermediate=8192"
7296        );
7297        assert!(
7298            !shaders_large.contains("vec_tile[4096]"),
7299            "vec_tile should NOT be hardcoded to 4096"
7300        );
7301    }
7302
7303    #[test]
7304    fn generated_cargo_toml_has_server_deps() {
7305        let toml = generate_cargo_toml("my-model");
7306        assert!(
7307            toml.contains("tiny_http"),
7308            "Cargo.toml should depend on tiny_http"
7309        );
7310        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
7311        assert!(
7312            toml.contains("serde_json"),
7313            "Cargo.toml should depend on serde_json"
7314        );
7315    }
7316
/// The generated `main.rs` must contain the flag parsing and server setup
/// strings associated with serve mode.
#[test]
fn generated_main_rs_has_serve_mode() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    // (expected substring, failure message)
    let checks = [
        ("--serve", "main.rs should parse --serve flag"),
        ("--port", "main.rs should parse --port flag"),
        ("fn serve(", "main.rs should define serve function"),
        (
            "tiny_http::Server::http",
            "main.rs should create tiny_http server",
        ),
    ];
    for (needle, msg) in checks {
        assert!(source.contains(needle), "{msg}");
    }
}
7339
/// The generated `main.rs` must reference each HTTP route the server exposes.
#[test]
fn generated_main_rs_has_chat_completions_endpoint() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    for route in ["/v1/chat/completions", "/v1/models", "/health"] {
        let route_present = source.contains(route);
        assert!(route_present, "main.rs should handle {route} endpoint");
    }
}
7358
/// The generated `main.rs` must contain the server-sent-events markers used
/// for streamed responses.
#[test]
fn generated_main_rs_has_sse_streaming() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    // (expected substring, failure message)
    let markers = [
        (
            "text/event-stream",
            "main.rs should set SSE content type for streaming",
        ),
        ("chat.completion.chunk", "main.rs should emit SSE chunks"),
        ("[DONE]", "main.rs should emit [DONE] sentinel"),
    ];
    markers
        .iter()
        .for_each(|(needle, msg)| assert!(source.contains(needle), "{msg}"));
}
7377
/// The generated `main.rs` must define the chat formatter and emit the
/// ChatML start/end tokens.
#[test]
fn generated_main_rs_has_chat_message_formatting() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let has_formatter = source.contains("fn format_chat_messages");
    assert!(
        has_formatter,
        "main.rs should define format_chat_messages function"
    );

    // Both ChatML delimiters share the same failure message.
    for token in ["<|im_start|>", "<|im_end|>"] {
        assert!(source.contains(token), "main.rs should use ChatML format");
    }
}
7396
/// The generated `main.rs` must declare the request structs and mention
/// `Deserialize`.
#[test]
fn generated_main_rs_has_request_types() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    // (expected substring, failure message)
    let checks = [
        (
            "struct ChatRequest",
            "main.rs should define ChatRequest struct",
        ),
        (
            "struct ChatMessage",
            "main.rs should define ChatMessage struct",
        ),
        (
            "Deserialize",
            "main.rs should derive Deserialize for request types",
        ),
    ];
    for (needle, msg) in checks {
        assert!(source.contains(needle), "{msg}");
    }
}
7415
/// The generated `model.rs` must contain a `reset` method signature and a
/// statement that zeroes `self.pos`.
#[test]
fn generated_model_has_reset_method() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    let has_reset_fn = source.contains("pub fn reset(&mut self)");
    let zeroes_pos = source.contains("self.pos = 0");

    assert!(
        has_reset_fn,
        "model.rs should have a reset() method for multi-request serving"
    );
    assert!(zeroes_pos, "reset() should reset position to 0");
}
7430
/// The generated `main.rs` must still contain the CLI entry point and the
/// model forward calls alongside serve mode.
#[test]
fn generated_main_rs_cli_mode_still_works() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    // (expected substring, failure message)
    let checks = [
        ("fn cli_mode(", "main.rs should define cli_mode function"),
        ("model.forward", "main.rs should call model.forward"),
        (
            "model.forward_prefill",
            "main.rs should call model.forward_prefill",
        ),
    ];
    for (needle, msg) in checks {
        assert!(source.contains(needle), "{msg}");
    }
}
7450
7451    // ── Batched prefill tests ──────────────────────────────────────────
7452
/// Every batched kernel must be declared (`kernel void <name>`) in the
/// generated shader source.
#[test]
fn generated_shaders_contain_batch_kernels() {
    let shader_src = generate_metal_shaders(&minimal_config());

    // (kernel name, optional suffix appended to the failure message)
    let kernels = [
        ("matmul_vec_batch", ""),
        ("matmul_vec_q8_batch", ""),
        ("matmul_q8_gemm_batch", " (weight-reuse GEMM)"),
        ("matmul_vec_q4_batch", ""),
        ("rms_norm_batch", ""),
        ("silu_mul_fused_batch", ""),
        ("add_inplace_batch", ""),
        ("copy_embedding_batch", ""),
    ];
    for (kernel, note) in kernels {
        let declaration = format!("kernel void {kernel}");
        assert!(
            shader_src.contains(&declaration),
            "shaders should contain {kernel} kernel{note}"
        );
    }
}
7490
/// The generated `model.rs` must declare a `ComputePipelineState` field for
/// every batched pipeline listed below.
#[test]
fn generated_model_has_batch_pipelines() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    let expected_fields = [
        "matmul_batch_pipeline",
        "matmul_q8_batch_pipeline",
        "matmul_q8_gemm_batch_pipeline",
        "matmul_q4_batch_pipeline",
        "rms_norm_batch_pipeline",
        "rope_batch_pipeline",
        "silu_mul_fused_batch_pipeline",
        "add_inplace_batch_pipeline",
        "copy_embedding_batch_pipeline",
        "attention_batch_pipeline",
        "copy_kv_batch_pipeline",
    ];
    expected_fields.iter().for_each(|field| {
        let declaration = format!("{field}: ComputePipelineState");
        assert!(
            source.contains(&declaration),
            "MetalModel should have {field} field"
        );
    });
}
7515
/// The generated `model.rs` must declare a `Buffer` field for every batched
/// scratch buffer listed below.
#[test]
fn generated_model_has_batch_buffers() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    let expected_fields = [
        "batch_hidden_buf",
        "batch_residual_buf",
        "batch_qkv_buf",
        "batch_attn_out_buf",
        "batch_attn_proj_buf",
        "batch_gate_up_buf",
        "batch_ffn_hidden_buf",
        "batch_ffn_out_buf",
        "batch_tokens_buf",
        "batch_positions_buf",
    ];
    expected_fields.iter().for_each(|field| {
        let declaration = format!("{field}: Buffer");
        assert!(
            source.contains(&declaration),
            "MetalModel should have {field} field"
        );
    });
}
7539
/// The generated `model.rs` must define `forward_prefill_batch` and contain a
/// single-token call into it.
#[test]
fn generated_model_has_forward_prefill_batch() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    let has_batch_fn =
        source.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])");
    assert!(
        has_batch_fn,
        "MetalModel should have forward_prefill_batch method"
    );

    // The single-token prefill path is expected to wrap the batched one.
    let delegates = source.contains("self.forward_prefill_batch(&[token_id])");
    assert!(
        delegates,
        "forward_prefill should delegate to forward_prefill_batch"
    );
}
7556
/// The generated `model.rs` must declare `MAX_BATCH_SIZE` with the value 512.
#[test]
fn generated_model_has_max_batch_size_constant() {
    let source = generate_model_rs(&minimal_config()).unwrap();
    let declares_constant = source.contains("pub const MAX_BATCH_SIZE: usize = 512");

    assert!(
        declares_constant,
        "model.rs should define MAX_BATCH_SIZE constant"
    );
}
7567
/// The text of `forward_prefill_batch` (up to the following `reset` method)
/// must call each of the batched dispatch helpers listed below.
#[test]
fn forward_prefill_batch_uses_batch_dispatch() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    // Isolate the method text: from its signature up to the next `pub fn reset`,
    // or to the end of the file when that marker is absent.
    let signature_at = source
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &source[signature_at..];
    let method_text = match tail.find("\n    pub fn reset") {
        Some(cut) => &tail[..cut],
        None => tail,
    };

    // Batched norm/embedding/activation, causal attention, fused KV copy,
    // and fused RoPE Q+K, in the order they are checked.
    let dispatches = [
        "dispatch_rms_norm_batch",
        "dispatch_copy_embedding_batch",
        "dispatch_silu_mul_fused_batch",
        "dispatch_attention_batch",
        "dispatch_copy_kv_both_batch",
        "dispatch_rope_qk_batch",
    ];
    for dispatch in dispatches {
        assert!(
            method_text.contains(dispatch),
            "forward_prefill_batch should use {dispatch}"
        );
    }
}
7611
/// With the Q8 config, the text of `forward_prefill_batch` must call the Q8
/// batched matmul dispatch.
#[test]
fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
    let source = generate_model_rs(&minimal_q8_config()).unwrap();

    // Method text runs from the signature to the next `pub fn reset` (or EOF).
    let signature_at = source
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &source[signature_at..];
    let method_text = match tail.find("\n    pub fn reset") {
        Some(cut) => &tail[..cut],
        None => tail,
    };

    let uses_q8 = method_text.contains("dispatch_matmul_q8_batch");
    assert!(
        uses_q8,
        "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
    );
}
7631
/// With the Q4 config, the text of `forward_prefill_batch` must call the Q4
/// batched matmul dispatch.
#[test]
fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
    let source = generate_model_rs(&minimal_q4_config()).unwrap();

    // Method text runs from the signature to the next `pub fn reset` (or EOF).
    let signature_at = source
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &source[signature_at..];
    let method_text = match tail.find("\n    pub fn reset") {
        Some(cut) => &tail[..cut],
        None => tail,
    };

    let uses_q4 = method_text.contains("dispatch_matmul_q4_batch");
    assert!(
        uses_q4,
        "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
    );
}
7651
/// The generated `main.rs` must reference `forward_prefill_batch`.
#[test]
fn generated_main_rs_uses_batched_prefill() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();
    let mentions_batch_prefill = source.contains("forward_prefill_batch");

    assert!(
        mentions_batch_prefill,
        "main.rs should use forward_prefill_batch for prompt tokens"
    );
}
7662
/// With the default config, the text of `forward_prefill_batch` must call the
/// plain batched matmul dispatch and neither quantized variant.
#[test]
fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    // Method text runs from the signature to the next `pub fn reset` (or EOF).
    let signature_at = source
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &source[signature_at..];
    let method_text = match tail.find("\n    pub fn reset") {
        Some(cut) => &tail[..cut],
        None => tail,
    };

    assert!(
        method_text.contains("dispatch_matmul_batch"),
        "f32 forward_prefill_batch should use dispatch_matmul_batch"
    );

    // Neither quantized dispatch may appear on the f32 path.
    for quantized in ["dispatch_matmul_q8_batch", "dispatch_matmul_q4_batch"] {
        assert!(
            !method_text.contains(quantized),
            "f32 forward_prefill_batch should not use {quantized}"
        );
    }
}
7691
/// The text of `forward()` must perform the embedding lookup through a CPU
/// copy out of `embed_buf` rather than a GPU copy dispatch.
#[test]
fn forward_uses_cpu_embedding_lookup() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    // Isolate forward() only: stop at the first end marker that is present,
    // trying `forward_profile` before `forward_prefill`.
    let signature_at = source
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &source[signature_at..];
    let cut = [
        "\n    pub fn forward_profile",
        "\n    pub fn forward_prefill",
    ]
    .iter()
    .find_map(|&marker| tail.find(marker))
    .unwrap_or(tail.len());
    let forward_text = &tail[..cut];

    // CPU-side access to the embedding buffer.
    let reads_embed_on_cpu = forward_text.contains("embed_buf.contents()");
    assert!(
        reads_embed_on_cpu,
        "forward() should access embed_buf via CPU unified memory for embedding lookup"
    );
    let uses_memcpy = forward_text.contains("copy_nonoverlapping");
    assert!(
        uses_memcpy,
        "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
    );

    // No GPU copy dispatch sourced from the embedding buffer.
    let uses_gpu_copy = forward_text.contains("dispatch_copy_offset(&enc, &self.embed_buf");
    assert!(
        !uses_gpu_copy,
        "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
    );
}
7723
/// The generated `model.rs` must define `forward_profile` and contain its
/// timing output markers.
#[test]
fn forward_profile_method_exists() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    // (expected substring, failure message); the signature is checked first,
    // then the timing prefix and the three per-phase duration names.
    let checks = [
        (
            "pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>",
            "MetalModel should have forward_profile() method",
        ),
        (
            "[profile]",
            "forward_profile() should print timing with [profile] prefix",
        ),
        (
            "d_embed",
            "forward_profile() should measure embedding time",
        ),
        ("d_layers", "forward_profile() should measure layer time"),
        ("d_logits", "forward_profile() should measure logits time"),
    ];
    for (needle, msg) in checks {
        assert!(source.contains(needle), "{msg}");
    }
}
7751
/// The generated `main.rs` must mention the `--profile` flag and the
/// `forward_profile` call.
#[test]
fn generated_cli_has_profile_flag() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let has_flag = source.contains("--profile");
    let calls_profile = source.contains("forward_profile");

    assert!(has_flag, "CLI should support --profile flag");
    assert!(
        calls_profile,
        "CLI should call forward_profile when --profile is set"
    );
}
7766
/// The generated `main.rs` must contain a `yield_now()` call.
#[test]
fn generated_cli_has_thermal_yield() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();
    let yields = source.contains("yield_now()");

    assert!(
        yields,
        "CLI generation loop should include thread::yield_now() for thermal management"
    );
}
7777
7778    // ── Real-world validation tests ──────────────────────────────────────
7779
/// `forward()` must be callable for the very first token (pos=0, a single KV
/// entry, no prefill context): the opening of the function must not contain
/// a `self.pos > 0` assertion, and the model must reference `self.pos`.
#[test]
fn generated_forward_handles_single_token_prompt() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // The forward function should accept any u32 token_id (no minimum pos guard)
    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("forward() must exist");

    // Inspect at most the first 400 bytes of forward(). The window end is
    // clamped to the source length and moved back to a UTF-8 character
    // boundary, so a short or non-ASCII generated file produces an assertion
    // failure below rather than a slice-index panic here. `forward_start`
    // is itself a boundary (it comes from `find`), so the loop terminates.
    let mut window_end = (forward_start + 400).min(model_rs.len());
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let forward_body = &model_rs[forward_start..window_end];

    // Should NOT require pos > 0 or seq_len > 1
    assert!(
        !forward_body.contains("assert!(self.pos > 0"),
        "forward() must accept pos=0 (first token with no prefill)"
    );

    // The attention kernel should handle seq_len=1 via the pos field
    assert!(
        model_rs.contains("self.pos"),
        "forward() should use self.pos to track sequence position"
    );
}
7806
/// After `reset()` the model must be in a clean state: the opening of the
/// method must zero `self.pos` and clear `self.prev_cmd`.
#[test]
fn generated_reset_clears_kv_cache_position() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let reset_start = model_rs
        .find("pub fn reset(&mut self)")
        .expect("reset() must exist");

    // Inspect at most the first 200 bytes of reset(). The window end is
    // clamped to the source length and moved back to a UTF-8 character
    // boundary, so a method that sits within 200 bytes of the end of the
    // file (or a window ending inside a multi-byte character) cannot trigger
    // a slice-index panic. `reset_start` is a boundary, so the loop terminates.
    let mut window_end = (reset_start + 200).min(model_rs.len());
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let reset_body = &model_rs[reset_start..window_end];

    // Reset must zero the position counter
    assert!(
        reset_body.contains("self.pos = 0"),
        "reset() must set self.pos = 0"
    );

    // Verify reset clears prev_cmd (double-buffering state)
    assert!(
        reset_body.contains("self.prev_cmd = None"),
        "reset() should clear prev_cmd for clean command buffer state"
    );
}
7831
/// The serve path must cope with an empty `messages` array: the formatter
/// loops over the slice (an empty slice yields an empty loop) and always
/// appends the assistant header, and `serve` resets the model.
#[test]
fn generated_serve_handles_empty_messages_gracefully() {
    let config = minimal_config();
    let main_rs = generate_main_rs("test-model", &config).unwrap();

    // The format_chat_messages function should exist and handle empty input
    let format_fn_start = main_rs
        .find("fn format_chat_messages")
        .expect("format_chat_messages must exist");

    // Inspect at most the first 500 bytes of the formatter. The end is
    // clamped to the source length and then moved back to a UTF-8 character
    // boundary so the slice cannot panic when the window would otherwise end
    // inside a multi-byte character. `format_fn_start` is a boundary (it
    // comes from `find`), so the loop terminates.
    let mut window_end = (format_fn_start + 500).min(main_rs.len());
    while !main_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let format_fn_body = &main_rs[format_fn_start..window_end];

    // It should iterate over messages (an empty slice produces an empty loop)
    assert!(
        format_fn_body.contains("for msg in messages"),
        "format_chat_messages should iterate over the messages slice"
    );
    // It should always append the assistant prompt suffix
    assert!(
        format_fn_body.contains("<|im_start|>assistant"),
        "format_chat_messages should always append assistant prompt header"
    );

    // The serve function should call model.reset() before each request;
    // everything from `fn serve(` to the end of the file is searched.
    let serve_fn_start = main_rs
        .find("fn serve(")
        .expect("serve function must exist");
    let serve_fn_body = &main_rs[serve_fn_start..];
    assert!(
        serve_fn_body.contains("model.reset()"),
        "serve function should reset model between requests"
    );
}
7867
/// The text of `forward()` must increment `self.pos`, so the next token uses
/// the following RoPE position and KV cache offset.
#[test]
fn generated_model_forward_increments_pos() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    // Isolate forward(): stop at the first end marker that is present, tried
    // in this priority order, or run to the end of the file.
    let signature_at = source
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &source[signature_at..];
    let cut = [
        "\n    pub fn forward_profile",
        "\n    pub fn forward_prefill",
        "\n    fn dispatch_",
    ]
    .iter()
    .find_map(|&marker| tail.find(marker))
    .unwrap_or(tail.len());
    let forward_text = &tail[..cut];

    // Accept either spacing of the increment statement.
    let increments = ["self.pos += 1", "self.pos +=1"]
        .iter()
        .any(|&stmt| forward_text.contains(stmt));
    assert!(
        increments,
        "forward() must increment self.pos after processing a token"
    );
}
7891}