Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Convert an arbitrary model name into a string usable as a Cargo package
/// and binary name.
///
/// The name is lowercased, every character that is neither alphanumeric nor
/// `-` is replaced by `-`, and leading/trailing `-` are trimmed.
///
/// Two results that Cargo would reject as a package name are repaired:
/// - nothing left after trimming (empty or all-punctuation input) — the
///   fallback name `model` is returned;
/// - a result that does not begin with a letter (e.g. `7b-chat`) — it is
///   prefixed with `model-`.
///
/// Names that already sanitize to a letter-initial, non-empty string are
/// returned exactly as before.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    let lowered = name.to_lowercase();
    let dashed = lowered.replace(|c: char| !c.is_alphanumeric() && c != '-', "-");
    let trimmed = dashed.trim_matches('-');

    match trimmed.chars().next() {
        // Empty package names are invalid in a Cargo manifest.
        None => FALLBACK.to_string(),
        // After trimming, the first character is alphanumeric; if it is not a
        // letter it is a digit-like character, which Cargo refuses as the
        // start of a package name.
        Some(first) if !first.is_alphabetic() => format!("{FALLBACK}-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    r#"//
122// Auto-generated by ForgeLLM Metal codegen.
123// Metal Shading Language compute kernels for transformer inference.
124//
125// Optimized with simdgroup cooperative reductions, shared memory vector
126// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
127// lane, and fast:: math intrinsics for Apple Silicon throughput.
128//
129
130#include <metal_stdlib>
131#include <metal_simdgroup_matrix>
132using namespace metal;
133
134// ── Constants ───────────────────────────────────────────────────────────
135// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
136// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
137// per shared memory load vs 4-row, improving ILP and reducing launches.
138constant constexpr uint SIMDGROUPS_PER_TG = 8;
139constant constexpr uint ROWS_PER_SIMDGROUP = 8;
140constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
141
142// ── matmul_vec ──────────────────────────────────────────────────────────
143// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
144// Uses simdgroup cooperative dot product with shared memory vector caching
145// and float4 vectorized loads. Each simdgroup processes 8 rows for better
146// shared memory reuse (8x vector reuse per load) and instruction-level
147// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
148kernel void matmul_vec(
149    device const float* matrix [[buffer(0)]],
150    device const float* vector [[buffer(1)]],
151    device float* output       [[buffer(2)]],
152    constant uint& rows        [[buffer(3)]],
153    constant uint& cols        [[buffer(4)]],
154    uint tgid [[threadgroup_position_in_grid]],
155    uint tid [[thread_index_in_threadgroup]],
156    uint simd_lane [[thread_index_in_simdgroup]],
157    uint simd_id [[simdgroup_index_in_threadgroup]])
158{
159    // Cooperatively load vector into threadgroup shared memory
160    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
161    for (uint i = tid; i < cols; i += 256) {
162        vec_tile[i] = vector[i];
163    }
164    threadgroup_barrier(mem_flags::mem_threadgroup);
165
166    // Each simdgroup handles 8 consecutive rows
167    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
168    if (row_base >= rows) return;
169
170    uint base0 = row_base * cols;
171    uint base1 = (row_base + 1) * cols;
172    uint base2 = (row_base + 2) * cols;
173    uint base3 = (row_base + 3) * cols;
174    uint base4 = (row_base + 4) * cols;
175    uint base5 = (row_base + 5) * cols;
176    uint base6 = (row_base + 6) * cols;
177    uint base7 = (row_base + 7) * cols;
178
179    // float4 vectorized accumulation across 8 rows
180    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
181    float4 sum4_0 = float4(0.0f);
182    float4 sum4_1 = float4(0.0f);
183    float4 sum4_2 = float4(0.0f);
184    float4 sum4_3 = float4(0.0f);
185    float4 sum4_4 = float4(0.0f);
186    float4 sum4_5 = float4(0.0f);
187    float4 sum4_6 = float4(0.0f);
188    float4 sum4_7 = float4(0.0f);
189
190    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
191        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
192        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
193        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
194        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
195        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
196        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
197        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
198        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
199        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
200    }
201
202    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
203    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
204    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
205    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
206    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
207    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
208    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
209    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
210
211    // Handle remaining elements (cols not divisible by 128)
212    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
213        float vv = vec_tile[j];
214        sum0 += matrix[base0 + j] * vv;
215        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
216        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
217        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
218        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
219        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
220        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
221        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
222    }
223
224    // Simdgroup hardware warp-level reduction
225    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
226    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
227    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
228    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
229
230    // Only first lane writes the results
231    if (simd_lane == 0) {
232        if (row_base     < rows) output[row_base]     = sum0;
233        if (row_base + 1 < rows) output[row_base + 1] = sum1;
234        if (row_base + 2 < rows) output[row_base + 2] = sum2;
235        if (row_base + 3 < rows) output[row_base + 3] = sum3;
236        if (row_base + 4 < rows) output[row_base + 4] = sum4;
237        if (row_base + 5 < rows) output[row_base + 5] = sum5;
238        if (row_base + 6 < rows) output[row_base + 6] = sum6;
239        if (row_base + 7 < rows) output[row_base + 7] = sum7;
240    }
241}
242
243// ── rms_norm ────────────────────────────────────────────────────────────
244// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
245// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
246// via shared memory for minimal synchronization overhead.
247kernel void rms_norm(
248    device const float* input   [[buffer(0)]],
249    device const float* weight  [[buffer(1)]],
250    device float* output        [[buffer(2)]],
251    constant uint& n            [[buffer(3)]],
252    constant float& eps         [[buffer(4)]],
253    uint tid [[thread_index_in_threadgroup]])
254{
255    // Each thread accumulates partial sum-of-squares
256    float sum_sq = 0.0f;
257    for (uint i = tid; i < n; i += 256) {
258        float v = input[i];
259        sum_sq += v * v;
260    }
261
262    // Simdgroup-level reduction (hardware warp sum)
263    sum_sq = simd_sum(sum_sq);
264
265    // Cross-simdgroup reduction via shared memory
266    threadgroup float shared[8];
267    uint simd_id = tid / 32;
268    uint simd_lane = tid % 32;
269    if (simd_lane == 0) {
270        shared[simd_id] = sum_sq;
271    }
272    threadgroup_barrier(mem_flags::mem_threadgroup);
273
274    // First thread computes final inverse RMS
275    if (tid == 0) {
276        float total = 0.0f;
277        for (uint i = 0; i < 8; i++) {
278            total += shared[i];
279        }
280        shared[0] = fast::rsqrt(total / float(n) + eps);
281    }
282    threadgroup_barrier(mem_flags::mem_threadgroup);
283
284    float inv_rms = shared[0];
285
286    // Normalize
287    for (uint i = tid; i < n; i += 256) {
288        output[i] = input[i] * inv_rms * weight[i];
289    }
290}
291
292// ── rope ────────────────────────────────────────────────────────────────
293// Rotary Position Embedding applied in-place.
294// Each thread handles one (head, pair) combination.
295kernel void rope(
296    device float* data        [[buffer(0)]],
297    constant uint& num_heads  [[buffer(1)]],
298    constant uint& head_dim   [[buffer(2)]],
299    constant uint& pos        [[buffer(3)]],
300    constant float& theta     [[buffer(4)]],
301    uint id [[thread_position_in_grid]])
302{
303    uint half_dim = head_dim / 2;
304    uint total_pairs = num_heads * half_dim;
305    if (id >= total_pairs) return;
306
307    uint h = id / half_dim;
308    uint i = id % half_dim;
309    uint off = h * head_dim;
310
311    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
312    float angle = float(pos) * freq;
313    float c = cos(angle);
314    float s = sin(angle);
315
316    float x0 = data[off + 2 * i];
317    float x1 = data[off + 2 * i + 1];
318    data[off + 2 * i]     = x0 * c - x1 * s;
319    data[off + 2 * i + 1] = x0 * s + x1 * c;
320}
321
322// ── softmax ─────────────────────────────────────────────────────────────
323// Numerically stable softmax over a 1-D array.
324// Single-threadgroup kernel with cooperative reduction.
325kernel void softmax(
326    device float* data       [[buffer(0)]],
327    constant uint& n         [[buffer(1)]],
328    uint tid [[thread_index_in_threadgroup]],
329    uint tg_size [[threads_per_threadgroup]])
330{
331    threadgroup float shared_val[256];
332
333    // Pass 1: find max
334    float local_max = -INFINITY;
335    for (uint i = tid; i < n; i += tg_size) {
336        local_max = max(local_max, data[i]);
337    }
338    shared_val[tid] = local_max;
339    threadgroup_barrier(mem_flags::mem_threadgroup);
340
341    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
342        if (tid < stride) {
343            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
344        }
345        threadgroup_barrier(mem_flags::mem_threadgroup);
346    }
347    float max_val = shared_val[0];
348    threadgroup_barrier(mem_flags::mem_threadgroup);
349
350    // Pass 2: exp and sum
351    float local_sum = 0.0f;
352    for (uint i = tid; i < n; i += tg_size) {
353        float e = fast::exp(data[i] - max_val);
354        data[i] = e;
355        local_sum += e;
356    }
357    shared_val[tid] = local_sum;
358    threadgroup_barrier(mem_flags::mem_threadgroup);
359
360    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
361        if (tid < stride) {
362            shared_val[tid] += shared_val[tid + stride];
363        }
364        threadgroup_barrier(mem_flags::mem_threadgroup);
365    }
366    float inv_sum = 1.0f / shared_val[0];
367    threadgroup_barrier(mem_flags::mem_threadgroup);
368
369    // Pass 3: normalize
370    for (uint i = tid; i < n; i += tg_size) {
371        data[i] *= inv_sum;
372    }
373}
374
375// ── silu_mul ────────────────────────────────────────────────────────────
376// Fused SiLU activation * element-wise multiply:
377//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
378kernel void silu_mul(
379    device const float* gate [[buffer(0)]],
380    device const float* up   [[buffer(1)]],
381    device float* output     [[buffer(2)]],
382    constant uint& n         [[buffer(3)]],
383    uint id [[thread_position_in_grid]])
384{
385    if (id >= n) return;
386    float g = gate[id];
387    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
388}
389
390// ── silu_mul_fused ─────────────────────────────────────────────────────
391// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
392//   gate = gate_up[0..n], up = gate_up[n..2*n]
393//   output[i] = silu(gate_up[i]) * gate_up[n + i]
394kernel void silu_mul_fused(
395    device const float* gate_up [[buffer(0)]],
396    device float* output        [[buffer(1)]],
397    constant uint& n            [[buffer(2)]],
398    uint id [[thread_position_in_grid]])
399{
400    if (id >= n) return;
401    float g = gate_up[id];
402    float u = gate_up[n + id];
403    output[id] = (g / (1.0f + fast::exp(-g))) * u;
404}
405
406// ── elementwise_add ─────────────────────────────────────────────────────
407// Residual connection: output[i] = a[i] + b[i]
408kernel void elementwise_add(
409    device const float* a  [[buffer(0)]],
410    device const float* b  [[buffer(1)]],
411    device float* output   [[buffer(2)]],
412    constant uint& n       [[buffer(3)]],
413    uint id [[thread_position_in_grid]])
414{
415    if (id >= n) return;
416    output[id] = a[id] + b[id];
417}
418
419// ── copy_buffer ─────────────────────────────────────────────────────────
420// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
421// transitions. Used for KV cache updates and embedding lookup.
422kernel void copy_buffer(
423    device const float* src [[buffer(0)]],
424    device float* dst       [[buffer(1)]],
425    constant uint& count    [[buffer(2)]],
426    uint id [[thread_position_in_grid]])
427{
428    if (id < count) dst[id] = src[id];
429}
430
431// ── copy_offset ─────────────────────────────────────────────────────────
432// Copy with source offset (in floats). Used for embedding table lookup
433// where we need to copy a specific row from a large table.
434kernel void copy_offset(
435    device const float* src     [[buffer(0)]],
436    device float* dst           [[buffer(1)]],
437    constant uint& src_offset   [[buffer(2)]],  // in floats
438    constant uint& count        [[buffer(3)]],
439    uint id [[thread_position_in_grid]])
440{
441    if (id < count) dst[id] = src[src_offset + id];
442}
443
444// ── add_inplace ─────────────────────────────────────────────────────────
445// In-place residual connection: a[i] += b[i]
446// Avoids a separate blit copy for residual add, reducing encoder overhead.
447kernel void add_inplace(
448    device float* a        [[buffer(0)]],
449    device const float* b  [[buffer(1)]],
450    constant uint& n       [[buffer(2)]],
451    uint id [[thread_position_in_grid]])
452{
453    if (id >= n) return;
454    a[id] += b[id];
455}
456
457// ── matmul_vec_q8 ─────────────────────────────────────────────────────
458// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
459// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
460// Operates directly on quantized weights to halve memory bandwidth vs f32,
461// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
462//
463// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
464// because int8->float conversion doubles register demand.  Fully unrolled
465// inner loop with float4 vector loads from shared memory eliminates loop
466// overhead and enables better instruction scheduling.
467// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
468constant constexpr uint Q8_ROWS_PER_SG = 4;
469constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
470
471// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
472// doubles ALU work, so the same register budget applies).
473constant constexpr uint Q4_ROWS_PER_SG = 4;
474constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
475
476kernel void matmul_vec_q8(
477    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
478    device const float* vector   [[buffer(1)]],  // f32 input
479    device float* output         [[buffer(2)]],
480    constant uint& rows          [[buffer(3)]],
481    constant uint& cols          [[buffer(4)]],  // number of elements per row
482    uint tgid [[threadgroup_position_in_grid]],
483    uint tid [[thread_index_in_threadgroup]],
484    uint simd_lane [[thread_index_in_simdgroup]],
485    uint simd_id [[simdgroup_index_in_threadgroup]])
486{
487    // Load vector into shared memory
488    threadgroup float vec_tile[VEC_TILE_SIZE];
489    for (uint i = tid; i < cols; i += 256) {
490        vec_tile[i] = vector[i];
491    }
492    threadgroup_barrier(mem_flags::mem_threadgroup);
493
494    // Each simdgroup handles 4 consecutive rows (lower register pressure)
495    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
496    if (row_base >= rows) return;
497
498    // Q8_0: each block is 34 bytes for 32 elements
499    uint blocks_per_row = cols / 32;
500    uint row_bytes = blocks_per_row * 34;
501
502    // Pointers to each row's Q8_0 data
503    device const uchar* r0 = matrix + row_base * row_bytes;
504    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
505    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
506    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
507
508    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
509
510    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
511        uint bb = blk * 34;
512        uint vb = blk * 32;
513
514        // Prefetch all 4 scales
515        float sc0 = float(*(device const half*)(r0 + bb));
516        float sc1 = float(*(device const half*)(r1 + bb));
517        float sc2 = float(*(device const half*)(r2 + bb));
518        float sc3 = float(*(device const half*)(r3 + bb));
519
520        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
521        // Q8_0 block layout where the int8 data starts at offset +2 from a
522        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
523        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
524        // reduction in memory transactions. Metal's char16/packed_char16 are
525        // reserved types and packed_*int4 require >=4-byte alignment which
526        // this layout does not provide, so packed_short4 is the widest valid
527        // vectorized load.
528        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
529        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
530        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
531        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
532
533        // Load all 8 float4 vector values for this 32-element block from shared memory
534        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
535        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
536        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
537        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
538        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
539        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
540        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
541        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
542
543        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
544        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
545        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
546            short4 _s = short4(SHORT4); \
547            char2 _a = as_type<char2>(_s.x); \
548            char2 _b = as_type<char2>(_s.y); \
549            char2 _c = as_type<char2>(_s.z); \
550            char2 _d = as_type<char2>(_s.w); \
551            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
552            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
553        }
554
555        float4 f0, f1;
556        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
557
558        // Row 0: 4 short4 loads cover 32 int8 weights
559        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
560        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
561        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
562        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
563
564        // Row 1
565        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
566        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
567        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
568        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
569
570        // Row 2
571        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
572        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
573        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
574        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
575
576        // Row 3
577        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
578        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
579        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
580        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
581
582        #undef Q8_UNPACK8
583
584        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
585    }
586
587    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
588    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
589
590    if (simd_lane == 0) {
591        if (row_base     < rows) output[row_base]     = sum0;
592        if (row_base + 1 < rows) output[row_base + 1] = sum1;
593        if (row_base + 2 < rows) output[row_base + 2] = sum2;
594        if (row_base + 3 < rows) output[row_base + 3] = sum3;
595    }
596}
597
598// ── matmul_vec_q4 ─────────────────────────────────────────────────────
599// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
600// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
601// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
602// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
603//
604// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
605// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
606kernel void matmul_vec_q4(
607    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
608    device const float* vector   [[buffer(1)]],  // f32 input
609    device float* output         [[buffer(2)]],
610    constant uint& rows          [[buffer(3)]],
611    constant uint& cols          [[buffer(4)]],  // number of elements per row
612    uint tgid [[threadgroup_position_in_grid]],
613    uint tid [[thread_index_in_threadgroup]],
614    uint simd_lane [[thread_index_in_simdgroup]],
615    uint simd_id [[simdgroup_index_in_threadgroup]])
616{
617    // Load vector into shared memory
618    threadgroup float vec_tile[VEC_TILE_SIZE];
619    for (uint i = tid; i < cols; i += 256) {
620        vec_tile[i] = vector[i];
621    }
622    threadgroup_barrier(mem_flags::mem_threadgroup);
623
624    // Each simdgroup handles 4 consecutive rows
625    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
626    if (row_base >= rows) return;
627
628    // Q4_0: each block is 18 bytes for 32 elements
629    uint blocks_per_row = cols / 32;
630    uint row_bytes = blocks_per_row * 18;
631
632    // Pointers to each row's Q4_0 data
633    device const uchar* r0 = matrix + row_base * row_bytes;
634    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
635    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
636    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
637
638    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
639
640    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
641        uint bb = blk * 18;
642        uint vb = blk * 32;
643
644        // Prefetch all 4 scales
645        float sc0 = float(*(device const half*)(r0 + bb));
646        float sc1 = float(*(device const half*)(r1 + bb));
647        float sc2 = float(*(device const half*)(r2 + bb));
648        float sc3 = float(*(device const half*)(r3 + bb));
649
650        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
651        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
652        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
653        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
654        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
655
656        // Load 8 float4 vector values for 32 elements from shared memory
657        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
658        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
659        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
660        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
661        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
662        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
663        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
664        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
665        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
666
667        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
668        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
669        float bd0=0, bd1=0, bd2=0, bd3=0;
670        uchar4 b;
671
672        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
673        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
674                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
675                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
676                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
677        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
678                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
679                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
680                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
681        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
682                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
683                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
684                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
685        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
686                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
687                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
688                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
689
690        // Row 1
691        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
692                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
693                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
694                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
695        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
696                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
697                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
698                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
699        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
700                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
701                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
702                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
703        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
704                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
705                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
706                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
707
708        // Row 2
709        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
710                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
711                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
712                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
713        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
714                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
715                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
716                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
717        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
718                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
719                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
720                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
721        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
722                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
723                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
724                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
725
726        // Row 3
727        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
728                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
729                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
730                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
731        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
732                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
733                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
734                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
735        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
736                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
737                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
738                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
739        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
740                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
741                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
742                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
743
744        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
745    }
746
747    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
748    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
749
750    if (simd_lane == 0) {
751        if (row_base     < rows) output[row_base]     = sum0;
752        if (row_base + 1 < rows) output[row_base + 1] = sum1;
753        if (row_base + 2 < rows) output[row_base + 2] = sum2;
754        if (row_base + 3 < rows) output[row_base + 3] = sum3;
755    }
756}
757
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then weighted sum of V.
// Each threadgroup handles one head with 256 threads (8 simdgroups).
//
// The per-head score buffer lives in threadgroup memory and holds at most
// 2048 entries. When seq_len exceeds that capacity the kernel attends over
// the most recent 2048 cached positions only (sliding window) instead of
// indexing past the end of the buffer. For seq_len <= 2048 every cached
// position is attended.
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;
    // Grouped-query attention: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Score buffer: 2048 entries (8 KB) of threadgroup memory.
    const uint max_scores = 2048u;
    threadgroup float scores[2048];

    // Clamp the attended range to the buffer capacity. win_start is the
    // first cache position that is attended; it is 0 whenever the whole
    // sequence fits, so short sequences behave exactly as before.
    uint win_len = min(uint(seq_len), max_scores);
    uint win_start = seq_len - win_len;

    // Offset of this KV head at position win_start, and the distance
    // between consecutive positions in the K/V caches.
    uint kv_stride = num_kv_heads * head_dim;
    uint kv_base = win_start * kv_stride + kv_head * head_dim;

    // Step 1: Compute attention scores Q·K^T with simdgroup reduction.
    // Each simdgroup takes every 8th position; its 32 lanes split head_dim.
    float scale = fast::rsqrt(float(head_dim));  // loop-invariant
    for (uint s = simd_id; s < win_len; s += 8) {
        uint k_off = kv_base + s * kv_stride;
        float qk = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            qk += q[q_off + d] * k_cache[k_off + d];
        }
        qk = simd_sum(qk);
        if (simd_lane == 0) {
            scores[s] = qk * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax over scores (cooperative)
    // Find max: per-thread, then per-simdgroup, then across the 8 simdgroups.
    float local_max = -INFINITY;
    for (uint s = tid; s < win_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exp and sum (max is subtracted first for numerical stability)
    float local_sum = 0.0;
    for (uint s = tid; s < win_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    // Normalize so the scores sum to 1.
    for (uint s = tid; s < win_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V: output = scores · V
    // Each thread handles a range of head_dim dimensions.
    // Process 4 sequence positions at a time for better ILP and reduced
    // loop overhead (4 score reads, 4 V loads per iteration).
    uint win_len4 = win_len & ~3u;  // largest multiple of 4 <= win_len
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_base + d;
        for (uint s = 0; s < win_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * v_cache[s * kv_stride + v_base]
                 + sc1 * v_cache[(s+1) * kv_stride + v_base]
                 + sc2 * v_cache[(s+2) * kv_stride + v_base]
                 + sc3 * v_cache[(s+3) * kv_stride + v_base];
        }
        // Remainder positions (at most 3).
        for (uint s = win_len4; s < win_len; s++) {
            acc += scores[s] * v_cache[s * kv_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
874
875// ── Batched prefill kernels ────────────────────────────────────────────
876// These kernels process M input vectors against the same weight matrix
877// in a single dispatch, converting mat-vec into mat-mat for better GPU
878// utilization during prompt prefill.
879
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization. For each of the M rows of length n:
//   output[row, i] = input[row, i] * rsqrt(mean(input[row]^2) + eps) * weight[i]
// Dispatch one threadgroup per token (M threadgroups). The reduction below
// is written for 256 threads per threadgroup (8 simdgroups of 32 lanes).
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // Surplus threadgroups have no row to normalize.
    if (tgid >= num_tokens) return;

    const uint row_off = tgid * n;
    const uint sg_index = tid / 32;
    const uint lane = tid % 32;

    // One partial sum of squares per simdgroup; slot 0 is later reused to
    // broadcast the final scale factor.
    threadgroup float partials[8];

    // Each thread accumulates squares over a strided slice of the row.
    float partial = 0.0f;
    for (uint k = tid; k < n; k += 256) {
        float x = input[row_off + k];
        partial += x * x;
    }

    // Reduce across the 32 lanes, then publish one value per simdgroup.
    partial = simd_sum(partial);
    if (lane == 0) {
        partials[sg_index] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the 8 partials into 1 / rms.
    if (tid == 0) {
        float sum_all = 0.0f;
        for (uint g = 0; g < 8; g++) {
            sum_all += partials[g];
        }
        partials[0] = fast::rsqrt(sum_all / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float scale = partials[0];

    // Apply the scale and the learned per-channel weight.
    for (uint k = tid; k < n; k += 256) {
        output[row_off + k] = input[row_off + k] * scale * weight[k];
    }
}
929
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place Rotary Position Embedding over a batch of token vectors, each
// with its own position.
// data layout: [M, num_heads * head_dim], positions: [M]
// One thread rotates one adjacent (even, odd) element pair, i.e. one
// (token, head, pair) combination; threads past the last pair do nothing.
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread id into (token, head, pair).
    const uint token = id / pairs_per_token;
    const uint within = id % pairs_per_token;
    const uint head = within / pairs_per_head;
    const uint pair = within % pairs_per_head;

    // Indices of the two elements rotated by this thread.
    const uint even = token * (num_heads * head_dim) + head * head_dim + 2 * pair;
    const uint odd = even + 1;

    // Rotation angle = position * theta^(-2 * pair / head_dim).
    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float phi = float(positions[token]) * inv_freq;
    const float cos_phi = cos(phi);
    const float sin_phi = sin(phi);

    // Read both inputs before writing either so the rotation is consistent.
    const float a = data[even];
    const float b = data[odd];
    data[even] = a * cos_phi - b * sin_phi;
    data[odd]  = a * sin_phi + b * cos_phi;
}
964
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-multiply over a batch. Each row of gate_up holds n gate values
// followed by n up values ([M, 2*n]); the result is [M, n] with
//   output[token*n + i] = silu(gate[token, i]) * up[token, i]
// One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint row = id / n;
    const uint col = id % n;

    // Gate element for this output; the matching up element sits n later.
    const uint gate_idx = row * 2 * n + col;
    const float gate = gate_up[gate_idx];
    const float up = gate_up[gate_idx + n];

    // silu(x) = x / (1 + exp(-x))
    const float activated = gate / (1.0f + fast::exp(-gate));
    output[row * n + col] = activated * up;
}
984
// ── add_inplace_batch ──────────────────────────────────────────────────
// Residual add over a flattened batch: a[i] = a[i] + b[i] for i < total,
// where total = M * n. One thread per element; surplus threads are idle.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
996
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous [M, dim] buffer.
// tokens[m] is the table row copied into output row m. One thread copies
// one float.
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    // Split the flat index into (batch row, column within the row).
    const uint batch_row = id / dim;
    const uint col = id % dim;

    const uint table_row = tokens[batch_row];
    output[id] = embed[table_row * dim + col];
}
1014
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair.
//
// Within a threadgroup each simdgroup accumulates up to 8 consecutive output
// rows (sum0..sum7), its 32 lanes splitting the column range. ROWS_PER_TG,
// ROWS_PER_SIMDGROUP and VEC_TILE_SIZE are defined elsewhere in this file.
// NOTE(review): the unrolling presumably matches ROWS_PER_SIMDGROUP == 8 and
// the tile load assumes cols <= VEC_TILE_SIZE; neither is checked here.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Split the flat threadgroup id into (token, row group within token).
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // First row owned by this simdgroup. The early return happens after the
    // barrier above, so every thread still takes part in the tile load.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Element offsets of the 8 candidate rows. Rows at or beyond `rows` are
    // never dereferenced (guarded below); only row 0 is known to be valid.
    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Largest multiple of 128 <= cols: the span covered by whole iterations
    // of 32 lanes * 4 floats in the vectorized loop.
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    // Vectorized main loop: each lane loads one float4 of the input tile and
    // reuses it against the matching float4 of every valid row.
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Collapse each float4 accumulator to a per-lane scalar.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail: the remaining cols - cols_vec4 columns, one per lane.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Reduce the per-lane partial dot products across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    // Lane 0 writes the results for the rows that exist.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1116
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups.
//
// Weight layout as read here: each row is cols/32 blocks of 34 bytes, a
// 2-byte half scale followed by 32 signed 8-bit weights. Each simdgroup
// produces up to 4 consecutive output rows; its lanes stride over blocks.
// Rows past the end of the matrix are clamped to the last valid row for
// loads (so no read leaves the buffer) and are never stored.
// Q8_ROWS_PER_TG, Q8_ROWS_PER_SG and VEC_TILE_SIZE are defined elsewhere in
// this file.
// NOTE(review): the tile load assumes cols <= VEC_TILE_SIZE and the block
// loop assumes cols is a multiple of 32; neither is checked here.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Split the flat threadgroup id into (token, row group within token).
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // First row owned by this simdgroup; returning after the barrier keeps
    // every thread participating in the tile load above.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4 the trailing rows of the last group
    // do not exist. Point them at the last valid row so their loads stay
    // inside the matrix; their sums are discarded by the guarded stores.
    // rows >= 1 here because row_base < rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of this block within a row
        uint vb = blk * 32;  // matching element offset in the input tile

        // Per-block scale of each row.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 input values for this block, shared by all 4 rows.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 bytes) into 8 signed bytes and widen
        // them to two float4s: OUT_LO = bytes 0..3, OUT_HI = bytes 4..7.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        // Unscaled integer-weight dot products for this block, one per row.
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Reduce the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes the results for the rows that exist.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1232
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers 32 rows and TOKENS_PER_TG_Q8 consecutive tokens, so
// the Q8_0 weight blocks are fetched once from device memory and reused for
// every token in the tile (1/TOKENS_PER_TG_Q8 the weight bandwidth of the
// per-token dispatch).
//
// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
// with simd_sum.  Token vectors are read directly from device memory inside
// the block loop (not cached in shared memory) so intermediate_size up to
// 8192 fits without spilling threadgroup memory.
//
// Number of consecutive tokens covered by one threadgroup tile of
// matmul_q8_gemm_batch.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1246
1247kernel void matmul_q8_gemm_batch(
1248    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1249    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1250    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1251    constant uint& num_tokens    [[buffer(3)]],  // M
1252    constant uint& rows          [[buffer(4)]],
1253    constant uint& cols          [[buffer(5)]],
1254    uint2 tgid [[threadgroup_position_in_grid]],
1255    uint tid [[thread_index_in_threadgroup]],
1256    uint simd_lane [[thread_index_in_simdgroup]],
1257    uint simd_id [[simdgroup_index_in_threadgroup]])
1258{
1259    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1260    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1261    if (row_base >= rows || tok_base >= num_tokens) return;
1262
1263    // How many tokens in this tile are valid?
1264    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1265
1266    uint blocks_per_row = cols / 32;
1267    uint row_bytes = blocks_per_row * 34;
1268
1269    device const uchar* r0 = matrix + row_base * row_bytes;
1270    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1271    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1272    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1273
1274    // Accumulators: 4 tokens × 4 rows per simdgroup.
1275    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1276    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1277    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1278    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1279
1280    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1281        uint bb = blk * 34;
1282        uint vb = blk * 32;
1283
1284        // ── Load weight data ONCE per block (reused across all tokens) ──
1285        float sc0 = float(*(device const half*)(r0 + bb));
1286        float sc1 = float(*(device const half*)(r1 + bb));
1287        float sc2 = float(*(device const half*)(r2 + bb));
1288        float sc3 = float(*(device const half*)(r3 + bb));
1289
1290        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1291        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1292        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1293        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1294
1295        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1296            short4 _s = short4(SHORT4); \
1297            char2 _a = as_type<char2>(_s.x); \
1298            char2 _b = as_type<char2>(_s.y); \
1299            char2 _c = as_type<char2>(_s.z); \
1300            char2 _d = as_type<char2>(_s.w); \
1301            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1302            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1303        }
1304
1305        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1306        // registers for the duration of the block and are dotted against
1307        // every token's vector tile.
1308        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1309        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1310        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1311        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1312
1313        Q8_UNPACK8(d0[0], w0_0, w0_1);
1314        Q8_UNPACK8(d0[1], w0_2, w0_3);
1315        Q8_UNPACK8(d0[2], w0_4, w0_5);
1316        Q8_UNPACK8(d0[3], w0_6, w0_7);
1317
1318        Q8_UNPACK8(d1[0], w1_0, w1_1);
1319        Q8_UNPACK8(d1[1], w1_2, w1_3);
1320        Q8_UNPACK8(d1[2], w1_4, w1_5);
1321        Q8_UNPACK8(d1[3], w1_6, w1_7);
1322
1323        Q8_UNPACK8(d2[0], w2_0, w2_1);
1324        Q8_UNPACK8(d2[1], w2_2, w2_3);
1325        Q8_UNPACK8(d2[2], w2_4, w2_5);
1326        Q8_UNPACK8(d2[3], w2_6, w2_7);
1327
1328        Q8_UNPACK8(d3[0], w3_0, w3_1);
1329        Q8_UNPACK8(d3[1], w3_2, w3_3);
1330        Q8_UNPACK8(d3[2], w3_4, w3_5);
1331        Q8_UNPACK8(d3[3], w3_6, w3_7);
1332
1333        #undef Q8_UNPACK8
1334
1335        // ── For each token, read vector and accumulate against shared weights ──
1336        // Token 0 (always valid: tok_count >= 1).
1337        {
1338            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1339            float4 v0 = *(device const float4*)(a0);
1340            float4 v1 = *(device const float4*)(a0 + 4);
1341            float4 v2 = *(device const float4*)(a0 + 8);
1342            float4 v3 = *(device const float4*)(a0 + 12);
1343            float4 v4 = *(device const float4*)(a0 + 16);
1344            float4 v5 = *(device const float4*)(a0 + 20);
1345            float4 v6 = *(device const float4*)(a0 + 24);
1346            float4 v7 = *(device const float4*)(a0 + 28);
1347            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1348                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1349            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1350                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1351            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1352                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1353            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1354                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1355            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1356        }
1357        // Token 1
1358        if (tok_count > 1) {
1359            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1360            float4 v0 = *(device const float4*)(a1);
1361            float4 v1 = *(device const float4*)(a1 + 4);
1362            float4 v2 = *(device const float4*)(a1 + 8);
1363            float4 v3 = *(device const float4*)(a1 + 12);
1364            float4 v4 = *(device const float4*)(a1 + 16);
1365            float4 v5 = *(device const float4*)(a1 + 20);
1366            float4 v6 = *(device const float4*)(a1 + 24);
1367            float4 v7 = *(device const float4*)(a1 + 28);
1368            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1369                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1370            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1371                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1372            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1373                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1374            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1375                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1376            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1377        }
1378        // Token 2
1379        if (tok_count > 2) {
1380            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1381            float4 v0 = *(device const float4*)(a2);
1382            float4 v1 = *(device const float4*)(a2 + 4);
1383            float4 v2 = *(device const float4*)(a2 + 8);
1384            float4 v3 = *(device const float4*)(a2 + 12);
1385            float4 v4 = *(device const float4*)(a2 + 16);
1386            float4 v5 = *(device const float4*)(a2 + 20);
1387            float4 v6 = *(device const float4*)(a2 + 24);
1388            float4 v7 = *(device const float4*)(a2 + 28);
1389            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1390                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1391            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1392                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1393            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1394                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1395            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1396                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1397            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1398        }
1399        // Token 3
1400        if (tok_count > 3) {
1401            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1402            float4 v0 = *(device const float4*)(a3);
1403            float4 v1 = *(device const float4*)(a3 + 4);
1404            float4 v2 = *(device const float4*)(a3 + 8);
1405            float4 v3 = *(device const float4*)(a3 + 12);
1406            float4 v4 = *(device const float4*)(a3 + 16);
1407            float4 v5 = *(device const float4*)(a3 + 20);
1408            float4 v6 = *(device const float4*)(a3 + 24);
1409            float4 v7 = *(device const float4*)(a3 + 28);
1410            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1411                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1412            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1413                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1414            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1415                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1416            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1417                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1418            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1419        }
1420    }
1421
1422    // simdgroup reduction
1423    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1424    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1425    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1426    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1427
1428    if (simd_lane == 0) {
1429        device float* o0 = outputs + (tok_base + 0) * rows;
1430        if (row_base     < rows) o0[row_base]     = s00;
1431        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1432        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1433        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1434
1435        if (tok_count > 1) {
1436            device float* o1 = outputs + (tok_base + 1) * rows;
1437            if (row_base     < rows) o1[row_base]     = s10;
1438            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1439            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1440            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1441        }
1442        if (tok_count > 2) {
1443            device float* o2 = outputs + (tok_base + 2) * rows;
1444            if (row_base     < rows) o2[row_base]     = s20;
1445            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1446            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1447            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1448        }
1449        if (tok_count > 3) {
1450            device float* o3 = outputs + (tok_base + 3) * rows;
1451            if (row_base     < rows) o3[row_base]     = s30;
1452            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1453            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1454            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1455        }
1456    }
1457}
1458
1459// ── matmul_q8_mma ──────────────────────────────────────────────────────
1460// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
1461// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  This dispatches
1462// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
1463// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
1464//
1465// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
1466// 4 simdgroups per TG, each computing a single 8×8 output sub-tile via one
1467// simdgroup_matrix<float, 8, 8> accumulator.  Weight bytes are cooperatively
1468// dequantized into threadgroup memory once per block and reused by all
1469// simdgroups in the tile.
1470//
1471// Assumptions (verified in the dispatch helper, falls back otherwise):
1472//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
1473//   * rows  % 16 == 0   (tile-aligned; true for all supported architectures)
1474//   * num_tokens may be any value; a partial 8-token sub-tile at the end of
1475//     the token range is written through a per-simdgroup scratch copy path.
1476constant constexpr uint MMA_TOK_TILE = 16;  // tokens covered by one threadgroup (tgid.y stride)
1477constant constexpr uint MMA_ROW_TILE = 16;  // weight rows covered by one threadgroup (tgid.x stride)
1478
1479kernel void matmul_q8_mma(
1480    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
1481    device const float* inputs   [[buffer(1)]],  // [M, cols]
1482    device float* outputs        [[buffer(2)]],  // [M, rows]
1483    constant uint& num_tokens    [[buffer(3)]],  // M: token count; bounds the token dimension below
1484    constant uint& rows          [[buffer(4)]],  // weight rows; also the element stride of one output token
1485    constant uint& cols          [[buffer(5)]],  // weight columns; also the element stride of one input token
1486    uint2 tgid [[threadgroup_position_in_grid]],  // x selects the row tile, y selects the token tile
1487    uint tid [[thread_index_in_threadgroup]],
1488    uint simd_id [[simdgroup_index_in_threadgroup]])
1489{
1490    uint row_base = tgid.x * MMA_ROW_TILE;
1491    uint tok_base = tgid.y * MMA_TOK_TILE;
1492    if (row_base >= rows || tok_base >= num_tokens) return;  // depends only on tgid, so the whole threadgroup exits together before any barrier
1493
1494    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
1495    threadgroup float w_tile[MMA_ROW_TILE * 32];  // dequantized weights, indexed [r * 32 + k]
1496    threadgroup float t_tile[MMA_TOK_TILE * 32];  // token activations, indexed [m * 32 + k]
1497
1498    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
1499    uint sg_tok_base = (simd_id / 2) * 8;  // row within output tile (token dim)
1500    uint sg_row_base = (simd_id % 2) * 8;  // col within output tile (row dim)
1501
1502    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);  // accumulator for this simdgroup's 8×8 sub-tile
1503
1504    uint blocks_per_row = cols / 32;
1505    uint row_bytes = blocks_per_row * 34;  // 34 bytes per block: 2-byte half scale followed by 32 int8 values (see reads below)
1506
1507    for (uint blk = 0; blk < blocks_per_row; blk++) {
1508        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
1509        // 512 floats / 128 threads = 4 floats per thread.
1510        {
1511            uint base = tid * 4;  // 4 consecutive tile entries per thread; all four share one row r since 4 divides 32
1512            for (uint ii = 0; ii < 4; ii++) {
1513                uint idx = base + ii;
1514                uint r = idx / 32;
1515                uint k = idx % 32;
1516                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;  // start of block blk in weight row row_base + r (row index is not bounds-checked here)
1517                float sc = float(*(device const half*)rp);  // per-block scale, stored as half in the first 2 bytes
1518                int ival = int(*(device const int8_t*)(rp + 2 + k));  // k-th signed 8-bit quant after the scale
1519                w_tile[r * 32 + k] = float(ival) * sc;
1520            }
1521        }
1522
1523        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
1524        {
1525            uint base = tid * 4;
1526            for (uint ii = 0; ii < 4; ii++) {
1527                uint idx = base + ii;
1528                uint m = idx / 32;
1529                uint k = idx % 32;
1530                uint tok = tok_base + m;
1531                t_tile[m * 32 + k] = (tok < num_tokens)  // tokens past the end are zero-filled instead of read out of bounds
1532                    ? inputs[tok * cols + blk * 32 + k]
1533                    : 0.0f;
1534            }
1535        }
1536
1537        threadgroup_barrier(mem_flags::mem_threadgroup);  // both tiles must be fully written before any simdgroup loads from them
1538
1539        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
1540        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
1541        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
1542        // C[m, r] += A[m, k] * B[k, r]
1543        for (uint k_sub = 0; k_sub < 4; k_sub++) {
1544            simdgroup_matrix<float, 8, 8> A, B;
1545            simdgroup_load(A,
1546                t_tile + sg_tok_base * 32 + k_sub * 8,
1547                32,  // tile row stride in elements
1548                ulong2(0, 0),
1549                false);
1550            simdgroup_load(B,
1551                w_tile + sg_row_base * 32 + k_sub * 8,
1552                32,  // tile row stride in elements
1553                ulong2(0, 0),
1554                true);  // transpose flag, per the B[k, r] note above
1555            simdgroup_multiply_accumulate(C, A, B, C);
1556        }
1557
1558        threadgroup_barrier(mem_flags::mem_threadgroup);  // all simdgroups finish reading before the next block overwrites the tiles
1559    }
1560
1561    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
1562    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
1563    uint out_tok = tok_base + sg_tok_base;
1564    uint out_row = row_base + sg_row_base;
1565    bool full_tok = (out_tok + 8 <= num_tokens);  // only the token dimension is checked; the row dimension is not
1566    if (full_tok) {
1567        // Fast path: entire 8×8 sub-tile is in-bounds.
1568        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
1569    } else if (out_tok < num_tokens) {  // sub-tiles that start at or past num_tokens store nothing
1570        // Partial row at the last token tile: stage in per-simdgroup scratch
1571        // and scalar-copy the valid rows.
1572        threadgroup float scratch[4 * 64];  // one 8×8 float slot per simdgroup; sized for simd_id in 0..3
1573        simdgroup_store(C, scratch + simd_id * 64, 8);
1574        simdgroup_barrier(mem_flags::mem_threadgroup);  // simdgroup-level barrier between the scratch store and the reads below
1575        uint lane = tid % 32;  // NOTE(review): assumes 32-thread simdgroups with consecutive thread indices — confirm
1576        if (lane == 0) {
1577            uint valid = num_tokens - out_tok;  // 1..7
1578            for (uint m = 0; m < valid; m++) {
1579                device float* dst = outputs + (out_tok + m) * rows + out_row;
1580                threadgroup const float* src = scratch + simd_id * 64 + m * 8;  // row m of this simdgroup's staged 8×8 tile
1581                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
1582                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
1583            }
1584        }
1585    }
1586}
1587
1588// ── matmul_q8_mma32 ────────────────────────────────────────────────────
1589// Larger-tile variant of matmul_q8_mma for long-context prefill.
1590//
1591// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
1592// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
1593// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
1594//
1595//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
1596//     output sub-tiles (tok, row):
1597//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
1598//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
1599//
1600// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
1601// iteration — better FLOP/load ratio than the 16×16 single-accumulator
1602// kernel — and needs a quarter as many threadgroups as the 16×16 tile.
1603//
1604// Assumptions (verified in dispatch helper, fallback otherwise):
1605//   * cols % 32 == 0
1606//   * rows % 32 == 0
1607constant constexpr uint MMA32_TOK_TILE = 32;  // tokens covered by one threadgroup (tgid.y stride)
1608constant constexpr uint MMA32_ROW_TILE = 32;  // weight rows covered by one threadgroup (tgid.x stride)
1609
1610kernel void matmul_q8_mma32(
1611    device const uchar* matrix   [[buffer(0)]],  // Q8_0 weight bytes, addressed as rows × (cols/32) blocks of 34 bytes
1612    device const float* inputs   [[buffer(1)]],  // activations, indexed [tok * cols + k]
1613    device float* outputs        [[buffer(2)]],  // results, indexed [tok * rows + row]
1614    constant uint& num_tokens    [[buffer(3)]],
1615    constant uint& rows          [[buffer(4)]],
1616    constant uint& cols          [[buffer(5)]],
1617    uint2 tgid [[threadgroup_position_in_grid]],  // x selects the row tile, y selects the token tile
1618    uint tid [[thread_index_in_threadgroup]],
1619    uint simd_id [[simdgroup_index_in_threadgroup]])
1620{
1621    uint row_base = tgid.x * MMA32_ROW_TILE;
1622    uint tok_base = tgid.y * MMA32_TOK_TILE;
1623    if (row_base >= rows || tok_base >= num_tokens) return;  // depends only on tgid, so the whole threadgroup exits together before any barrier
1624
1625    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
1626    threadgroup float w_tile[MMA32_ROW_TILE * 32];  // dequantized weights, indexed [r * 32 + k]
1627    threadgroup float t_tile[MMA32_TOK_TILE * 32];  // token activations, indexed [m * 32 + k]
1628
1629    uint sg_tok_idx  = simd_id / 2;      // 0..3
1630    uint sg_row_half = simd_id % 2;      // 0..1
1631    uint sg_tok_base = sg_tok_idx * 8;   // first token of this simdgroup's 8-token band within the tile
1632    uint sg_row_base_a = sg_row_half * 16 + 0;  // first row of accumulator C_a within the tile
1633    uint sg_row_base_b = sg_row_half * 16 + 8;  // first row of accumulator C_b (the next 8 rows)
1634
1635    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
1636    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);
1637
1638    uint blocks_per_row = cols / 32;
1639    uint row_bytes = blocks_per_row * 34;  // 34 bytes per block: 2-byte half scale followed by 32 int8 values (see reads below)
1640
1641    for (uint blk = 0; blk < blocks_per_row; blk++) {
1642        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
1643        {
1644            uint base = tid * 4;  // 4 consecutive tile entries per thread; all four share one row r since 4 divides 32
1645            for (uint ii = 0; ii < 4; ii++) {
1646                uint idx = base + ii;
1647                uint r = idx / 32;
1648                uint k = idx % 32;
1649                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;  // start of block blk in weight row row_base + r (row index is not bounds-checked here)
1650                float sc = float(*(device const half*)rp);  // per-block scale, stored as half in the first 2 bytes
1651                int ival = int(*(device const int8_t*)(rp + 2 + k));  // k-th signed 8-bit quant after the scale
1652                w_tile[r * 32 + k] = float(ival) * sc;
1653            }
1654        }
1655
1656        // Cooperative token tile load.
1657        {
1658            uint base = tid * 4;
1659            for (uint ii = 0; ii < 4; ii++) {
1660                uint idx = base + ii;
1661                uint m = idx / 32;
1662                uint k = idx % 32;
1663                uint tok = tok_base + m;
1664                t_tile[m * 32 + k] = (tok < num_tokens)  // tokens past the end are zero-filled instead of read out of bounds
1665                    ? inputs[tok * cols + blk * 32 + k]
1666                    : 0.0f;
1667            }
1668        }
1669
1670        threadgroup_barrier(mem_flags::mem_threadgroup);  // both tiles must be fully written before any simdgroup loads from them
1671
1672        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
1673        for (uint k_sub = 0; k_sub < 4; k_sub++) {
1674            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
1675            simdgroup_load(A,
1676                t_tile + sg_tok_base * 32 + k_sub * 8,
1677                32,  // tile row stride in elements
1678                ulong2(0, 0),
1679                false);
1680            simdgroup_load(B_a,
1681                w_tile + sg_row_base_a * 32 + k_sub * 8,
1682                32,
1683                ulong2(0, 0),
1684                true);  // transpose flag: w_tile is stored row-major as [r][k]
1685            simdgroup_load(B_b,
1686                w_tile + sg_row_base_b * 32 + k_sub * 8,
1687                32,
1688                ulong2(0, 0),
1689                true);
1690            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
1691            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
1692        }
1693
1694        threadgroup_barrier(mem_flags::mem_threadgroup);  // all simdgroups finish reading before the next block overwrites the tiles
1695    }
1696
1697    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
1698    // (verified in dispatch), so full simdgroup_store is safe for the row
1699    // dimension; only the last token tile may be partial.
1700    uint out_tok = tok_base + sg_tok_base;
1701    uint out_row_a = row_base + sg_row_base_a;
1702    uint out_row_b = row_base + sg_row_base_b;
1703    bool full_tok = (out_tok + 8 <= num_tokens);
1704    if (full_tok) {
1705        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
1706        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
1707    } else if (out_tok < num_tokens) {  // token bands that start at or past num_tokens store nothing
1708        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
1709        simdgroup_store(C_a, scratch + simd_id * 128, 8);
1710        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
1711        simdgroup_barrier(mem_flags::mem_threadgroup);  // simdgroup-level barrier between the scratch stores and the reads below
1712        uint lane = tid % 32;  // NOTE(review): assumes 32-thread simdgroups with consecutive thread indices — confirm
1713        if (lane == 0) {
1714            uint valid = num_tokens - out_tok;  // 1..7
1715            for (uint m = 0; m < valid; m++) {
1716                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
1717                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
1718                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;  // row m of staged C_a
1719                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;  // row m of staged C_b
1720                for (uint j = 0; j < 8; j++) {
1721                    dst_a[j] = src_a[j];
1722                    dst_b[j] = src_b[j];
1723                }
1724            }
1725        }
1726    }
1727}
1728
1729// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
1730// FP16 threadgroup-tile variant of matmul_q8_mma32.
1731//
1732// Stores dequantized weights and token inputs as `half` in threadgroup
1733// memory — halving the shared-memory footprint (4 KB total vs 8 KB) and
1734// doubling concurrent-threadgroup occupancy per GPU core on Apple Silicon.
1735// Q8_0 weights are int8 × an f16 per-block scale, so a f16 intermediate
1736// representation covers their dynamic range (products round to half).  Token
1737// activations stay numerically safe because the subsequent
1738// `simdgroup_multiply_accumulate` keeps the accumulator in `float`.
1739//
1740// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
1741// accumulators each.  Primary win vs mma32 is occupancy at moderate
1742// prefill lengths where the GPU is wave-starved.
1743kernel void matmul_q8_mma32_h(
1744    device const uchar* matrix   [[buffer(0)]],  // Q8_0 weight bytes, addressed as rows × (cols/32) blocks of 34 bytes
1745    device const float* inputs   [[buffer(1)]],  // activations, indexed [tok * cols + k]
1746    device float* outputs        [[buffer(2)]],  // results, indexed [tok * rows + row]
1747    constant uint& num_tokens    [[buffer(3)]],
1748    constant uint& rows          [[buffer(4)]],
1749    constant uint& cols          [[buffer(5)]],
1750    uint2 tgid [[threadgroup_position_in_grid]],  // x selects the row tile, y selects the token tile
1751    uint tid [[thread_index_in_threadgroup]],
1752    uint simd_id [[simdgroup_index_in_threadgroup]])
1753{
1754    uint row_base = tgid.x * MMA32_ROW_TILE;
1755    uint tok_base = tgid.y * MMA32_TOK_TILE;
1756    if (row_base >= rows || tok_base >= num_tokens) return;  // depends only on tgid, so the whole threadgroup exits together before any barrier
1757
1758    // 32×32 half tiles — 2 KB each, 4 KB total.
1759    threadgroup half w_tile[MMA32_ROW_TILE * 32];  // dequantized weights narrowed to half, indexed [r * 32 + k]
1760    threadgroup half t_tile[MMA32_TOK_TILE * 32];  // token activations narrowed to half, indexed [m * 32 + k]
1761
1762    uint sg_tok_idx  = simd_id / 2;  // 8-token band index within the tile
1763    uint sg_row_half = simd_id % 2;  // which 16-row half of the tile
1764    uint sg_tok_base = sg_tok_idx * 8;
1765    uint sg_row_base_a = sg_row_half * 16 + 0;  // first row of accumulator C_a within the tile
1766    uint sg_row_base_b = sg_row_half * 16 + 8;  // first row of accumulator C_b (the next 8 rows)
1767
1768    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);  // accumulators stay float even though operands are half
1769    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);
1770
1771    uint blocks_per_row = cols / 32;
1772    uint row_bytes = blocks_per_row * 34;  // 34 bytes per block: 2-byte half scale followed by 32 int8 values (see reads below)
1773
1774    for (uint blk = 0; blk < blocks_per_row; blk++) {
1775        // Cooperative weight dequantization to FP16.
1776        {
1777            uint base = tid * 4;  // 4 tile entries per thread; NOTE(review): covers 32*32 only with 256 threads per threadgroup — confirm in dispatch
1778            for (uint ii = 0; ii < 4; ii++) {
1779                uint idx = base + ii;
1780                uint r = idx / 32;
1781                uint k = idx % 32;
1782                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;  // start of block blk in weight row row_base + r (row index is not bounds-checked here)
1783                float sc = float(*(device const half*)rp);  // per-block scale, stored as half in the first 2 bytes
1784                int ival = int(*(device const int8_t*)(rp + 2 + k));  // k-th signed 8-bit quant after the scale
1785                w_tile[r * 32 + k] = half(float(ival) * sc);  // product computed in float, then rounded to half
1786            }
1787        }
1788
1789        // Cooperative token tile load (f32 → f16 narrowing).
1790        {
1791            uint base = tid * 4;
1792            for (uint ii = 0; ii < 4; ii++) {
1793                uint idx = base + ii;
1794                uint m = idx / 32;
1795                uint k = idx % 32;
1796                uint tok = tok_base + m;
1797                t_tile[m * 32 + k] = (tok < num_tokens)  // tokens past the end are zero-filled instead of read out of bounds
1798                    ? half(inputs[tok * cols + blk * 32 + k])
1799                    : half(0);
1800            }
1801        }
1802
1803        threadgroup_barrier(mem_flags::mem_threadgroup);  // both tiles must be fully written before any simdgroup loads from them
1804
1805        for (uint k_sub = 0; k_sub < 4; k_sub++) {  // 4 chunks of 8 along K; A is reused for both row accumulators
1806            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
1807            simdgroup_load(A,
1808                t_tile + sg_tok_base * 32 + k_sub * 8,
1809                32,  // tile row stride in elements
1810                ulong2(0, 0),
1811                false);
1812            simdgroup_load(B_a,
1813                w_tile + sg_row_base_a * 32 + k_sub * 8,
1814                32,
1815                ulong2(0, 0),
1816                true);  // transpose flag: w_tile is stored row-major as [r][k]
1817            simdgroup_load(B_b,
1818                w_tile + sg_row_base_b * 32 + k_sub * 8,
1819                32,
1820                ulong2(0, 0),
1821                true);
1822            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
1823            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
1824        }
1825
1826        threadgroup_barrier(mem_flags::mem_threadgroup);  // all simdgroups finish reading before the next block overwrites the tiles
1827    }
1828
1829    uint out_tok = tok_base + sg_tok_base;
1830    uint out_row_a = row_base + sg_row_base_a;
1831    uint out_row_b = row_base + sg_row_base_b;
1832    bool full_tok = (out_tok + 8 <= num_tokens);  // only the token dimension is checked; the row dimension is not
1833    if (full_tok) {
1834        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);  // direct 8×8 store with output stride = rows
1835        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
1836    } else if (out_tok < num_tokens) {  // token bands that start at or past num_tokens store nothing
1837        threadgroup float scratch[8 * 2 * 64];  // two 8×8 float slots per simdgroup; sized for simd_id in 0..7
1838        simdgroup_store(C_a, scratch + simd_id * 128, 8);
1839        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
1840        simdgroup_barrier(mem_flags::mem_threadgroup);  // simdgroup-level barrier between the scratch stores and the reads below
1841        uint lane = tid % 32;  // NOTE(review): assumes 32-thread simdgroups with consecutive thread indices — confirm
1842        if (lane == 0) {
1843            uint valid = num_tokens - out_tok;  // in 1..7 here, because out_tok < num_tokens < out_tok + 8
1844            for (uint m = 0; m < valid; m++) {
1845                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
1846                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
1847                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;  // row m of staged C_a
1848                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;  // row m of staged C_b
1849                for (uint j = 0; j < 8; j++) {
1850                    dst_a[j] = src_a[j];
1851                    dst_b[j] = src_b[j];
1852                }
1853            }
1854        }
1855    }
1856}
1857
1858// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
1859// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
1860//
1861// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
1862// 4 simdgroups × **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
1863//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
1864//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
1865// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
1866//
1867// Per K_sub iteration: load two A fragments and two B fragments, then run
1868// **four** MMA instructions reusing A_top with both B's and A_bot with
1869// both B's.  That's 4 MMAs per 4 simdgroup_loads versus 2 per 3 in the
1870// 2-accumulator kernel, and it halves the thread count per threadgroup (128
1871// threads), which often improves occupancy on Apple GPUs where the
1872// concurrent-thread budget is the tighter limit than shared-memory size.
1873kernel void matmul_q8_mma32_h4(
1874    device const uchar* matrix   [[buffer(0)]],  // Q8_0 weight bytes, addressed as rows × (cols/32) blocks of 34 bytes
1875    device const float* inputs   [[buffer(1)]],  // activations, indexed [tok * cols + k]
1876    device float* outputs        [[buffer(2)]],  // results, indexed [tok * rows + row]
1877    constant uint& num_tokens    [[buffer(3)]],
1878    constant uint& rows          [[buffer(4)]],
1879    constant uint& cols          [[buffer(5)]],
1880    uint2 tgid [[threadgroup_position_in_grid]],  // x selects the row tile, y selects the token tile
1881    uint tid [[thread_index_in_threadgroup]],
1882    uint simd_id [[simdgroup_index_in_threadgroup]])
1883{
1884    uint row_base = tgid.x * MMA32_ROW_TILE;
1885    uint tok_base = tgid.y * MMA32_TOK_TILE;
1886    if (row_base >= rows || tok_base >= num_tokens) return;  // depends only on tgid, so the whole threadgroup exits together before any barrier
1887
1888    // 32×32 FP16 tiles, 4 KB total.
1889    threadgroup half w_tile[MMA32_ROW_TILE * 32];  // dequantized weights narrowed to half, indexed [r * 32 + k]
1890    threadgroup half t_tile[MMA32_TOK_TILE * 32];  // token activations narrowed to half, indexed [m * 32 + k]
1891
1892    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
1893    uint sg_tok_q = simd_id / 2;   // 0..1
1894    uint sg_row_q = simd_id % 2;   // 0..1
1895    uint sg_tok_base = sg_tok_q * 16;  // first token of this simdgroup's quadrant within the tile
1896    uint sg_row_base = sg_row_q * 16;  // first row of this simdgroup's quadrant within the tile
1897
1898    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);  // tokens +0..7,  rows +0..7
1899    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);  // tokens +0..7,  rows +8..15
1900    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);  // tokens +8..15, rows +0..7
1901    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);  // tokens +8..15, rows +8..15
1902
1903    uint blocks_per_row = cols / 32;
1904    uint row_bytes = blocks_per_row * 34;  // 34 bytes per block: 2-byte half scale followed by 32 int8 values (see reads below)
1905
1906    for (uint blk = 0; blk < blocks_per_row; blk++) {
1907        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
1908        {
1909            uint base = tid * 8;  // 8 consecutive tile entries per thread; all eight share one row r since 8 divides 32
1910            for (uint ii = 0; ii < 8; ii++) {
1911                uint idx = base + ii;
1912                uint r = idx / 32;
1913                uint k = idx % 32;
1914                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;  // start of block blk in weight row row_base + r (row index is not bounds-checked here)
1915                float sc = float(*(device const half*)rp);  // per-block scale, stored as half in the first 2 bytes
1916                int ival = int(*(device const int8_t*)(rp + 2 + k));  // k-th signed 8-bit quant after the scale
1917                w_tile[r * 32 + k] = half(float(ival) * sc);  // product computed in float, then rounded to half
1918            }
1919        }
1920
1921        // Cooperative token tile load.
1922        {
1923            uint base = tid * 8;
1924            for (uint ii = 0; ii < 8; ii++) {
1925                uint idx = base + ii;
1926                uint m = idx / 32;
1927                uint k = idx % 32;
1928                uint tok = tok_base + m;
1929                t_tile[m * 32 + k] = (tok < num_tokens)  // tokens past the end are zero-filled instead of read out of bounds
1930                    ? half(inputs[tok * cols + blk * 32 + k])
1931                    : half(0);
1932            }
1933        }
1934
1935        threadgroup_barrier(mem_flags::mem_threadgroup);  // both tiles must be fully written before any simdgroup loads from them
1936
1937        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
1938        for (uint k_sub = 0; k_sub < 4; k_sub++) {
1939            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
1940            simdgroup_load(A_top,
1941                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
1942                32,  // tile row stride in elements
1943                ulong2(0, 0),
1944                false);
1945            simdgroup_load(A_bot,
1946                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
1947                32,
1948                ulong2(0, 0),
1949                false);
1950            simdgroup_load(B_lo,
1951                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
1952                32,
1953                ulong2(0, 0),
1954                true);  // transpose flag: w_tile is stored row-major as [r][k]
1955            simdgroup_load(B_hi,
1956                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
1957                32,
1958                ulong2(0, 0),
1959                true);
1960            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
1961            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
1962            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
1963            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
1964        }
1965
1966        threadgroup_barrier(mem_flags::mem_threadgroup);  // all simdgroups finish reading before the next block overwrites the tiles
1967    }
1968
1969    // Store 4 output tiles.  Full-tile fast path assumes full 16×16 valid.
1970    uint out_tok_top = tok_base + sg_tok_base + 0;
1971    uint out_tok_bot = tok_base + sg_tok_base + 8;
1972    uint out_row_lo  = row_base + sg_row_base + 0;
1973    uint out_row_hi  = row_base + sg_row_base + 8;
1974    bool full = (out_tok_bot + 8 <= num_tokens);  // true only when all 16 tokens of the quadrant exist; rows are not checked
1975    if (full) {
1976        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
1977        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
1978        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
1979        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
1980    } else {  // also taken when the quadrant is entirely past num_tokens; the per-token checks below then write nothing
1981        // Partial-token fallback via per-simdgroup scratch.
1982        threadgroup float scratch[4 * 4 * 64];  // 4 simdgroups × 4 accs × 64
1983        uint sg_base = simd_id * 256;  // this simdgroup's 4 × 64-float region of scratch
1984        simdgroup_store(C_00, scratch + sg_base +   0, 8);
1985        simdgroup_store(C_01, scratch + sg_base +  64, 8);
1986        simdgroup_store(C_10, scratch + sg_base + 128, 8);
1987        simdgroup_store(C_11, scratch + sg_base + 192, 8);
1988        simdgroup_barrier(mem_flags::mem_threadgroup);  // simdgroup-level barrier between the scratch stores and the reads below
1989        uint lane = tid % 32;  // NOTE(review): assumes 32-thread simdgroups with consecutive thread indices — confirm
1990        if (lane == 0) {
1991            for (uint m = 0; m < 8; m++) {
1992                uint t_top = out_tok_top + m;
1993                if (t_top < num_tokens) {  // copy row m of C_00 / C_01 only for tokens that exist
1994                    device float* dst0 = outputs + t_top * rows + out_row_lo;
1995                    device float* dst1 = outputs + t_top * rows + out_row_hi;
1996                    threadgroup const float* src0 = scratch + sg_base +   0 + m * 8;
1997                    threadgroup const float* src1 = scratch + sg_base +  64 + m * 8;
1998                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
1999                }
2000                uint t_bot = out_tok_bot + m;
2001                if (t_bot < num_tokens) {  // copy row m of C_10 / C_11 only for tokens that exist
2002                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
2003                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
2004                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
2005                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
2006                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
2007                }
2008            }
2009        }
2010    }
2011}
2012
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)).  Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Data layout as read by this kernel:
//   matrix  : quantized rows; every group of 32 weights occupies 34 bytes
//             (one half scale followed by 32 int8 values).
//   inputs  : [num_tokens, cols] floats, narrowed to half while staging.
//   outputs : [num_tokens, rows] floats.
// Each threadgroup owns one (row tile, token tile) pair selected by tgid;
// each simdgroup accumulates a 16x16 sub-tile as four 8x8 matrices.
// NOTE(review): the staging loops and the scratch size assume 4 simdgroups
// whose threads * 8 elements cover each staged tile exactly, and the weight
// staging / output stores have no per-row bound check, so `rows` is presumably
// a multiple of the row tile — confirm against the host dispatch.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Origin of this threadgroup's output tile; whole tiles outside the
    // problem exit before any barrier is reached.
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Staging tiles for one 32-wide K block: dequantized weights and tokens.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Simdgroups are arranged 2 wide along rows; each covers 16 tokens x 16 rows.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // Four 8x8 half accumulators: (top/bottom 8 tokens) x (low/high 8 rows).
    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        {
            // Weight staging: each thread dequantizes 8 consecutive tile
            // elements (int8 value * block scale) into half.
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        {
            // Token staging: same element split; tokens past num_tokens are
            // zero-filled so they contribute nothing to the product.
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        // Both tiles must be fully written before any simdgroup loads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // The 32-wide K block is consumed as four 8-wide slices.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), false);
            // Weight sub-tiles are loaded with the transpose flag set.
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Keep the next iteration's staging writes from racing these loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;

    // Scratch holds 4 accumulators of 64 halfs for each of 4 simdgroups;
    // every simdgroup touches only its own 256-element slice.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base +   0, 8);
    simdgroup_store(C_01, scratch + sg_base +  64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);
    // One lane per simdgroup copies its slice to device memory as floats,
    // skipping tokens beyond num_tokens.
    uint lane = tid % 32;
    if (lane == 0) {
        for (uint m = 0; m < 8; m++) {
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base +   0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base +  64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst0[j] = float(src0[j]);
                    dst1[j] = float(src1[j]);
                }
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst2[j] = float(src2[j]);
                    dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2150
// ── add_bias_batch ─────────────────────────────────────────────────────
// Adds bias[i] to element i of every token row of a [num_tokens, rows]
// output, in place. One thread per output element; threads whose index
// falls outside the matrix do nothing.
// Applied to the fused QKV matmul output for Qwen2, which carries a QKV bias.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],  // [num_tokens, rows], updated in place
    device const float* bias     [[buffer(1)]],  // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * rows) {
        // The column within the row selects the bias entry.
        out[id] = out[id] + bias[id % rows];
    }
}
2167
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup stages one token's input vector in threadgroup memory;
// each simdgroup then computes four consecutive output rows, with lanes
// striding over the 32-element quantization blocks and a simd_sum at the end.
// A block is 18 bytes: a half scale followed by 16 bytes of packed nibbles.
// Low nibbles pair with block elements 0..15, high nibbles with 16..31, and
// every nibble is re-centred by subtracting 8.
// NOTE(review): the staging loop strides by 256, so it presumably expects
// 256 threads per threadgroup and a tile at least `cols` floats long —
// confirm against the host dispatch.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decompose the flat threadgroup index into (token, row group).
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's input vector once for all simdgroups.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Safe to exit here: no barrier follows.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // Clamp the three trailing row indices so a partial final row group never
    // reads past the last matrix row. The values computed for clamped rows
    // are discarded by the guarded stores at the end of the kernel.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;   // byte offset of this block within a row
        uint vb = blk * 32;   // element offset of this block within the input

        // Per-row block scales.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Packed nibble payload, read four bytes at a time.
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // The 32 input values of this block, shared by all four rows.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Unscaled block dot products, one per row.
        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes the results; rows past the end of the matrix are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2268
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatters one projection (K or V) out of the strided batch QKV buffer into
// the KV cache. One thread per copied float.
// Source rows are src_stride floats apart, with the wanted slice starting
// src_offset floats into each row; cache rows are kv_dim floats wide and
// the batch lands at cache rows base_pos .. base_pos + M - 1.
kernel void copy_kv_batch(
    device const float* src  [[buffer(0)]],  // batch QKV buffer
    device float* dst        [[buffer(1)]],  // KV cache
    constant uint& M         [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    if (id < M * kv_dim) {
        // Split the flat index into (batch row, column within the KV vector).
        uint row = id / kv_dim;
        uint col = id % kv_dim;
        uint cache_index = (base_pos + row) * kv_dim + col;
        uint batch_index = row * src_stride + src_offset + col;
        dst[cache_index] = src[batch_index];
    }
}
2291
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// Three phases separated by threadgroup barriers:
//   1. scores[s] = (Q . K[s]) / sqrt(head_dim), one simdgroup per position;
//   2. softmax over scores[0..seq_len), reduced through simdgroup partials;
//   3. output = sum_s scores[s] * V[s], four dimensions per thread.
// Query heads map onto KV heads in groups of num_heads / num_kv_heads.
// NOTE(review): the strides of 8 and 256 and the 8-entry partial arrays
// presumably expect 256 threads (8 simdgroups) per threadgroup — confirm
// against the host dispatch.
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores, one slot per attended position.
    // Sized for the longest context the generator emits (MAX_SEQ_LEN is
    // capped at 4096); a 2048-entry array overflowed once
    // base_pos + token_idx reached 2048. 4096 floats = 16 KiB.
    threadgroup float scores[4096];

    // Step 1: Q * K^T with simdgroup reduction
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // 2a: maximum score, first per thread, then per simdgroup, then combined
    // by thread 0.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // 2b: exponentiate relative to the maximum and accumulate the sum.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Store the reciprocal; a zero total yields all-zero weights
        // instead of a division by zero.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // 2c: normalise.
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
    // This is much better than the scalar version where only 64 of 256 threads are active.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        // Positions are consumed four at a time, then a scalar tail.
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2416
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Applies rotary position embedding to the Q and K slices of every token
// row in one dispatch, rather than one dispatch per projection.
// One thread rotates one adjacent (even, odd) element pair in place.
// Within a token row, Q heads start at float 0 and K heads start right
// after them at num_q_heads * head_dim.
kernel void rope_qk_batch(
    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
    constant uint& M                 [[buffer(1)]],   // num tokens
    constant uint& base_pos          [[buffer(2)]],   // starting position
    constant uint& num_q_heads       [[buffer(3)]],
    constant uint& num_kv_heads      [[buffer(4)]],
    constant uint& head_dim          [[buffer(5)]],
    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
    constant float& theta            [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    uint pairs_per_head = head_dim / 2;
    uint pairs_per_token = (num_q_heads + num_kv_heads) * pairs_per_head;
    uint token = id / pairs_per_token;
    uint pair = id % pairs_per_token;
    if (token >= M) return;

    // Pairs are numbered Q-first; anything past the Q pairs belongs to K.
    uint q_pair_count = num_q_heads * pairs_per_head;
    bool in_q = pair < q_pair_count;
    uint pair_in_section = in_q ? pair : (pair - q_pair_count);
    uint section_start = in_q ? 0u : (num_q_heads * head_dim);

    uint head_idx = pair_in_section / pairs_per_head;
    uint i = pair_in_section % pairs_per_head;
    uint offset = token * qkv_stride + section_start + head_idx * head_dim + i * 2;

    // Rotation angle: absolute position times the pair's inverse frequency.
    uint pos = base_pos + token;
    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    float angle = float(pos) * freq;
    float cos_val = cos(angle);
    float sin_val = sin(angle);

    // Read both elements before writing either one back.
    float even_in = qkv_data[offset];
    float odd_in = qkv_data[offset + 1];
    qkv_data[offset]     = even_in * cos_val - odd_in * sin_val;
    qkv_data[offset + 1] = even_in * sin_val + odd_in * cos_val;
}
2467
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Copies the K and V slices of every token row from the strided batch QKV
// buffer into their two caches in one dispatch.
// The flat thread index covers all K elements first, then all V elements;
// each thread moves exactly one float.
kernel void copy_kv_both_batch(
    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
    constant uint& M           [[buffer(3)]],  // num tokens in batch
    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    // Element count of a single projection across the whole batch.
    uint per_projection = M * kv_dim;
    if (id >= per_projection * 2) return;

    // The upper half of the index space addresses V, the lower half K.
    bool to_v = id >= per_projection;
    uint local = to_v ? (id - per_projection) : id;
    uint row = local / kv_dim;
    uint col = local % kv_dim;

    uint slice_start = to_v ? v_offset : k_offset;
    uint cache_index = (base_pos + row) * kv_dim + col;
    float value = src[row * src_stride + slice_start + col];

    if (to_v) {
        v_dst[cache_index] = value;
    } else {
        k_dst[cache_index] = value;
    }
}
2502"#
2503    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2504}
2505
2506// ---------------------------------------------------------------------------
2507// model.rs generation
2508// ---------------------------------------------------------------------------
2509
2510fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2511    let mut code = String::with_capacity(48 * 1024);
2512    emit_model_header(&mut code, config)?;
2513    emit_metal_model_struct(&mut code, config)?;
2514    emit_layer_buffers_struct(&mut code, config)?;
2515    emit_metal_model_impl(&mut code, config)?;
2516    emit_helper_functions(&mut code)?;
2517    Ok(code)
2518}
2519
2520fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2521    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2522    writeln!(
2523        code,
2524        "//! Model: {} ({} layers, hidden={})",
2525        config.architecture, config.num_layers, config.hidden_size
2526    )?;
2527    writeln!(code, "//!")?;
2528    writeln!(
2529        code,
2530        "//! Uses native Metal compute pipelines via the metal crate."
2531    )?;
2532    writeln!(code)?;
2533    writeln!(code, "#![allow(dead_code)]")?;
2534    writeln!(code)?;
2535    writeln!(code, "use metal::*;")?;
2536    writeln!(code, "#[allow(unused_imports)]")?;
2537    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2538    writeln!(code, "use std::mem;")?;
2539    writeln!(code)?;
2540
2541    // Model constants
2542    writeln!(
2543        code,
2544        "// ── Model constants ──────────────────────────────────"
2545    )?;
2546    writeln!(
2547        code,
2548        "pub const HIDDEN_SIZE: usize = {};",
2549        config.hidden_size
2550    )?;
2551    writeln!(
2552        code,
2553        "pub const INTERMEDIATE_SIZE: usize = {};",
2554        config.intermediate_size
2555    )?;
2556    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2557    writeln!(
2558        code,
2559        "pub const NUM_HEADS: usize = {};",
2560        config.num_attention_heads
2561    )?;
2562    writeln!(
2563        code,
2564        "pub const NUM_KV_HEADS: usize = {};",
2565        config.num_kv_heads
2566    )?;
2567    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2568    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2569    let effective_seq_len = config.max_seq_len.min(4096);
2570    writeln!(
2571        code,
2572        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
2573        effective_seq_len, config.max_seq_len
2574    )?;
2575    writeln!(
2576        code,
2577        "pub const RMS_NORM_EPS: f32 = {:e};",
2578        config.rms_norm_eps
2579    )?;
2580    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
2581    writeln!(
2582        code,
2583        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
2584    )?;
2585    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
2586    writeln!(code)?;
2587
2588    Ok(())
2589}
2590
2591fn emit_metal_model_struct(
2592    code: &mut String,
2593    config: &ModelConfig,
2594) -> Result<(), MetalCodegenError> {
2595    writeln!(
2596        code,
2597        "// ── MetalModel ──────────────────────────────────────────"
2598    )?;
2599    writeln!(code)?;
2600    writeln!(
2601        code,
2602        "/// Metal-accelerated transformer model for Apple Silicon."
2603    )?;
2604    writeln!(code, "///")?;
2605    writeln!(
2606        code,
2607        "/// Uses unified memory for zero-copy weight access and native Metal"
2608    )?;
2609    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
2610    writeln!(code, "pub struct MetalModel {{")?;
2611    writeln!(code, "    device: Device,")?;
2612    writeln!(code, "    queue: CommandQueue,")?;
2613    writeln!(code)?;
2614    writeln!(code, "    // ── Compute pipelines ──")?;
2615    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
2616    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
2617    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
2618    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
2619    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
2620    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
2621    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
2622    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
2623    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
2624    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
2625    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
2626    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
2627    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
2628    writeln!(code)?;
2629    writeln!(code, "    // ── Batched prefill pipelines ──")?;
2630    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
2631    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
2632    writeln!(
2633        code,
2634        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
2635    )?;
2636    writeln!(
2637        code,
2638        "    matmul_q8_mma_pipeline: ComputePipelineState,"
2639    )?;
2640    writeln!(
2641        code,
2642        "    matmul_q8_mma32_pipeline: ComputePipelineState,"
2643    )?;
2644    writeln!(
2645        code,
2646        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
2647    )?;
2648    writeln!(
2649        code,
2650        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
2651    )?;
2652    writeln!(
2653        code,
2654        "    matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
2655    )?;
2656    if config.qkv_bias {
2657        writeln!(
2658            code,
2659            "    add_bias_batch_pipeline: ComputePipelineState,"
2660        )?;
2661    }
2662    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
2663    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
2664    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
2665    writeln!(
2666        code,
2667        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
2668    )?;
2669    writeln!(
2670        code,
2671        "    add_inplace_batch_pipeline: ComputePipelineState,"
2672    )?;
2673    writeln!(
2674        code,
2675        "    copy_embedding_batch_pipeline: ComputePipelineState,"
2676    )?;
2677    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
2678    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
2679    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
2680    writeln!(
2681        code,
2682        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
2683    )?;
2684    writeln!(code)?;
2685    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
2686    writeln!(
2687        code,
2688        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
2689    )?;
2690    writeln!(code, "    embed_buf: Buffer,")?;
2691    writeln!(code)?;
2692    writeln!(code, "    /// Per-layer weight buffers")?;
2693    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
2694    writeln!(code)?;
2695    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
2696    writeln!(code, "    norm_buf: Buffer,")?;
2697    writeln!(code)?;
2698    writeln!(
2699        code,
2700        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
2701    )?;
2702    writeln!(code, "    lm_head_buf: Buffer,")?;
2703    writeln!(code)?;
2704    writeln!(
2705        code,
2706        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
2707    )?;
2708    writeln!(code, "    hidden_buf: Buffer,")?;
2709    writeln!(code, "    residual_buf: Buffer,")?;
2710    writeln!(code, "    normed_buf: Buffer,")?;
2711    writeln!(
2712        code,
2713        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
2714    )?;
2715    writeln!(code, "    qkv_buf: Buffer,")?;
2716    writeln!(code, "    attn_out_buf: Buffer,")?;
2717    writeln!(code, "    attn_proj_buf: Buffer,")?;
2718    writeln!(
2719        code,
2720        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
2721    )?;
2722    writeln!(code, "    gate_up_buf: Buffer,")?;
2723    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
2724    writeln!(code, "    ffn_out_buf: Buffer,")?;
2725    writeln!(code, "    add_tmp_buf: Buffer,")?;
2726    writeln!(code, "    logits_buf: Buffer,")?;
2727    writeln!(code)?;
2728    writeln!(code, "    // ── Batched prefill working buffers ──")?;
2729    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
2730    writeln!(code, "    batch_hidden_buf: Buffer,")?;
2731    writeln!(
2732        code,
2733        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
2734    )?;
2735    writeln!(code, "    batch_residual_buf: Buffer,")?;
2736    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2737    writeln!(code, "    batch_qkv_buf: Buffer,")?;
2738    writeln!(
2739        code,
2740        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2741    )?;
2742    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
2743    writeln!(
2744        code,
2745        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2746    )?;
2747    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
2748    writeln!(
2749        code,
2750        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2751    )?;
2752    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
2753    writeln!(
2754        code,
2755        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2756    )?;
2757    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
2758    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2759    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
2760    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
2761    writeln!(code, "    batch_tokens_buf: Buffer,")?;
2762    writeln!(code, "    /// Positions buffer for batch RoPE")?;
2763    writeln!(code, "    batch_positions_buf: Buffer,")?;
2764    writeln!(code)?;
2765    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
2766    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
2767    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
2768    writeln!(code)?;
2769    writeln!(code, "    // ── Inference state ──")?;
2770    writeln!(code, "    pos: usize,")?;
2771    writeln!(code)?;
2772    writeln!(
2773        code,
2774        "    /// Previous command buffer for double-buffered prefill."
2775    )?;
2776    writeln!(
2777        code,
2778        "    /// While the GPU executes token N, the CPU can encode token N+1."
2779    )?;
2780    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
2781    writeln!(code, "}}")?;
2782    writeln!(code)?;
2783
2784    Ok(())
2785}
2786
2787fn emit_layer_buffers_struct(
2788    code: &mut String,
2789    config: &ModelConfig,
2790) -> Result<(), MetalCodegenError> {
2791    writeln!(
2792        code,
2793        "/// Per-layer weight buffers for attention and FFN projections."
2794    )?;
2795    writeln!(code, "struct LayerBuffers {{")?;
2796    writeln!(code, "    attn_norm: Buffer,")?;
2797    writeln!(
2798        code,
2799        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2800    )?;
2801    writeln!(code, "    qkv_weight: Buffer,")?;
2802    if config.qkv_bias {
2803        writeln!(
2804            code,
2805            "    /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
2806        )?;
2807        writeln!(code, "    qkv_bias: Buffer,")?;
2808    }
2809    writeln!(code, "    o_weight: Buffer,")?;
2810    writeln!(code, "    ffn_norm: Buffer,")?;
2811    writeln!(
2812        code,
2813        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2814    )?;
2815    writeln!(code, "    gate_up_weight: Buffer,")?;
2816    writeln!(code, "    down_weight: Buffer,")?;
2817    writeln!(code, "}}")?;
2818    writeln!(code)?;
2819
2820    Ok(())
2821}
2822
2823fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2824    let hidden = config.hidden_size;
2825    let intermediate = config.intermediate_size;
2826    let _num_layers = config.num_layers;
2827    let num_heads = config.num_attention_heads;
2828    let num_kv_heads = config.num_kv_heads;
2829    let head_dim = config.head_dim;
2830    let vocab = config.vocab_size;
2831    let effective_seq_len = config.max_seq_len.min(4096);
2832    let is_q8 = config.dtype == DType::Q8_0;
2833    let is_q4 = config.dtype == DType::Q4_0;
2834    let kv_dim = num_kv_heads * head_dim;
2835
2836    writeln!(code, "impl MetalModel {{")?;
2837
2838    // ── new() ──
2839    writeln!(
2840        code,
2841        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
2842    )?;
2843    writeln!(code, "    ///")?;
2844    writeln!(
2845        code,
2846        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
2847    )?;
2848    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
2849    writeln!(
2850        code,
2851        "        let device = Device::system_default().expect(\"no Metal device found\");"
2852    )?;
2853    writeln!(code, "        let queue = device.new_command_queue();")?;
2854    writeln!(code)?;
2855
2856    // Compile shaders
2857    writeln!(
2858        code,
2859        "        // Compile Metal shaders from embedded source"
2860    )?;
2861    writeln!(
2862        code,
2863        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
2864    )?;
2865    writeln!(
2866        code,
2867        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
2868    )?;
2869    writeln!(
2870        code,
2871        "            .expect(\"failed to compile Metal shaders\");"
2872    )?;
2873    writeln!(code)?;
2874
2875    // Create compute pipelines
2876    writeln!(code, "        // Create compute pipelines")?;
2877    for (var, fn_name) in [
2878        ("matmul_pipeline", "matmul_vec"),
2879        ("matmul_q8_pipeline", "matmul_vec_q8"),
2880        ("matmul_q4_pipeline", "matmul_vec_q4"),
2881        ("rms_norm_pipeline", "rms_norm"),
2882        ("rope_pipeline", "rope"),
2883        ("softmax_pipeline", "softmax"),
2884        ("silu_mul_pipeline", "silu_mul"),
2885        ("silu_mul_fused_pipeline", "silu_mul_fused"),
2886        ("add_pipeline", "elementwise_add"),
2887        ("attention_pipeline", "attention"),
2888        ("add_inplace_pipeline", "add_inplace"),
2889        ("copy_pipeline", "copy_buffer"),
2890        ("copy_offset_pipeline", "copy_offset"),
2891        ("matmul_batch_pipeline", "matmul_vec_batch"),
2892        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
2893        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
2894        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
2895        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
2896        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
2897        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
2898        ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
2899        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
2900        ("rms_norm_batch_pipeline", "rms_norm_batch"),
2901        ("rope_batch_pipeline", "rope_batch"),
2902        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
2903        ("add_inplace_batch_pipeline", "add_inplace_batch"),
2904        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
2905        ("attention_batch_pipeline", "attention_batch"),
2906        ("copy_kv_batch_pipeline", "copy_kv_batch"),
2907        ("rope_qk_batch_pipeline", "rope_qk_batch"),
2908        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
2909    ] {
2910        writeln!(
2911            code,
2912            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
2913        )?;
2914    }
2915    if config.qkv_bias {
2916        writeln!(
2917            code,
2918            "        let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
2919        )?;
2920    }
2921    writeln!(code)?;
2922
2923    // Weight loading
2924    writeln!(
2925        code,
2926        "        // Load weights into Metal shared-memory buffers"
2927    )?;
2928    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
2929    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
2930    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
2931    writeln!(code)?;
2932    writeln!(
2933        code,
2934        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
2935    )?;
2936    writeln!(code)?;
2937    writeln!(
2938        code,
2939        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
2940    )?;
2941    writeln!(
2942        code,
2943        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
2944    )?;
2945    writeln!(code, "            let byte_len = n * f32_size;")?;
2946    writeln!(code, "            let cur = cursor.get();")?;
2947    writeln!(
2948        code,
2949        "            let data = &weights[cur..cur + byte_len];"
2950    )?;
2951    writeln!(code, "            cursor.set(cur + byte_len);")?;
2952    writeln!(code, "            device.new_buffer_with_data(")?;
2953    writeln!(code, "                data.as_ptr() as *const _,")?;
2954    writeln!(code, "                byte_len as u64,")?;
2955    writeln!(
2956        code,
2957        "                MTLResourceOptions::StorageModeShared,"
2958    )?;
2959    writeln!(code, "            )")?;
2960    writeln!(code, "        }};")?;
2961    writeln!(code)?;
2962
2963    if is_q8 {
2964        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
2965        // We load them directly into Metal buffers without dequantizing,
2966        // and use the matmul_vec_q8 shader that operates on quantized data.
2967        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
2968        writeln!(
2969            code,
2970            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
2971        )?;
2972        writeln!(
2973            code,
2974            "        // as raw bytes into a Metal buffer (no dequantization)."
2975        )?;
2976        writeln!(
2977            code,
2978            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
2979        )?;
2980        writeln!(
2981            code,
2982            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
2983        )?;
2984        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2985        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
2986        writeln!(code, "            let total_raw = rows * row_bytes;")?;
2987        writeln!(code, "            let cur = cursor.get();")?;
2988        writeln!(
2989            code,
2990            "            let data = &weights[cur..cur + total_raw];"
2991        )?;
2992        writeln!(code, "            cursor.set(cur + total_raw);")?;
2993        writeln!(code, "            device.new_buffer_with_data(")?;
2994        writeln!(code, "                data.as_ptr() as *const _,")?;
2995        writeln!(code, "                total_raw as u64,")?;
2996        writeln!(
2997            code,
2998            "                MTLResourceOptions::StorageModeShared,"
2999        )?;
3000        writeln!(code, "            )")?;
3001        writeln!(code, "        }};")?;
3002        writeln!(code)?;
3003        writeln!(
3004            code,
3005            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3006        )?;
3007        writeln!(
3008            code,
3009            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3010        )?;
3011        writeln!(
3012            code,
3013            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3014        )?;
3015        writeln!(
3016            code,
3017            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3018        )?;
3019        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3020        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
3021        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3022        writeln!(code, "            let cur = cursor.get();")?;
3023        writeln!(
3024            code,
3025            "            let data = &weights[cur..cur + total_raw];"
3026        )?;
3027        writeln!(code, "            cursor.set(cur + total_raw);")?;
3028        writeln!(code, "            device.new_buffer_with_data(")?;
3029        writeln!(code, "                data.as_ptr() as *const _,")?;
3030        writeln!(code, "                total_raw as u64,")?;
3031        writeln!(
3032            code,
3033            "                MTLResourceOptions::StorageModeShared,"
3034        )?;
3035        writeln!(code, "            )")?;
3036        writeln!(code, "        }};")?;
3037        writeln!(code)?;
3038    }
3039
3040    if is_q4 {
3041        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
3042        // We load them directly into Metal buffers without dequantizing,
3043        // and use the matmul_vec_q4 shader that operates on quantized data.
3044        // This quarters GPU memory usage vs f32 dequantization.
3045        writeln!(
3046            code,
3047            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3048        )?;
3049        writeln!(
3050            code,
3051            "        // as raw bytes into a Metal buffer (no dequantization)."
3052        )?;
3053        writeln!(
3054            code,
3055            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3056        )?;
3057        writeln!(
3058            code,
3059            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3060        )?;
3061        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3062        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3063        writeln!(code, "            let total_raw = rows * row_bytes;")?;
3064        writeln!(code, "            let cur = cursor.get();")?;
3065        writeln!(
3066            code,
3067            "            let data = &weights[cur..cur + total_raw];"
3068        )?;
3069        writeln!(code, "            cursor.set(cur + total_raw);")?;
3070        writeln!(code, "            device.new_buffer_with_data(")?;
3071        writeln!(code, "                data.as_ptr() as *const _,")?;
3072        writeln!(code, "                total_raw as u64,")?;
3073        writeln!(
3074            code,
3075            "                MTLResourceOptions::StorageModeShared,"
3076        )?;
3077        writeln!(code, "            )")?;
3078        writeln!(code, "        }};")?;
3079        writeln!(code)?;
3080        writeln!(
3081            code,
3082            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3083        )?;
3084        writeln!(
3085            code,
3086            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3087        )?;
3088        writeln!(
3089            code,
3090            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3091        )?;
3092        writeln!(
3093            code,
3094            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3095        )?;
3096        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
3097        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
3098        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
3099        writeln!(code, "            let cur = cursor.get();")?;
3100        writeln!(
3101            code,
3102            "            let data = &weights[cur..cur + total_raw];"
3103        )?;
3104        writeln!(code, "            cursor.set(cur + total_raw);")?;
3105        writeln!(code, "            device.new_buffer_with_data(")?;
3106        writeln!(code, "                data.as_ptr() as *const _,")?;
3107        writeln!(code, "                total_raw as u64,")?;
3108        writeln!(
3109            code,
3110            "                MTLResourceOptions::StorageModeShared,"
3111        )?;
3112        writeln!(code, "            )")?;
3113        writeln!(code, "        }};")?;
3114        writeln!(code)?;
3115    }
3116
3117    writeln!(
3118        code,
3119        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
3120    )?;
3121    writeln!(code)?;
3122
3123    // Per-layer weights
3124    writeln!(
3125        code,
3126        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3127    )?;
3128    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
3129
3130    // attn_norm is always f32
3131    writeln!(
3132        code,
3133        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
3134    )?;
3135
3136    let qkv_rows = hidden + 2 * kv_dim;
3137    if is_q8 {
3138        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
3139        writeln!(
3140            code,
3141            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3142        )?;
3143        if config.qkv_bias {
3144            writeln!(
3145                code,
3146                "            // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3147            )?;
3148            writeln!(
3149                code,
3150                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3151            )?;
3152        }
3153        writeln!(
3154            code,
3155            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3156        )?;
3157    } else if is_q4 {
3158        writeln!(
3159            code,
3160            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3161        )?;
3162        if config.qkv_bias {
3163            writeln!(
3164                code,
3165                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3166            )?;
3167        }
3168        writeln!(
3169            code,
3170            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3171        )?;
3172    } else {
3173        writeln!(
3174            code,
3175            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3176        )?;
3177        if config.qkv_bias {
3178            writeln!(
3179                code,
3180                "            let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3181            )?;
3182        }
3183        writeln!(
3184            code,
3185            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3186        )?;
3187    }
3188
3189    // ffn_norm is always f32
3190    writeln!(
3191        code,
3192        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3193    )?;
3194
3195    let gate_up_rows = 2 * intermediate;
3196    if is_q8 {
3197        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
3198        writeln!(
3199            code,
3200            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3201        )?;
3202        writeln!(
3203            code,
3204            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3205        )?;
3206    } else if is_q4 {
3207        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3208        writeln!(
3209            code,
3210            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3211        )?;
3212        writeln!(
3213            code,
3214            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3215        )?;
3216    } else {
3217        // Fused gate+up weight: read both as a single contiguous f32 buffer
3218        writeln!(
3219            code,
3220            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3221        )?;
3222        writeln!(
3223            code,
3224            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3225        )?;
3226    }
3227
3228    writeln!(code, "            layers.push(LayerBuffers {{")?;
3229    writeln!(code, "                attn_norm,")?;
3230    writeln!(code, "                qkv_weight,")?;
3231    if config.qkv_bias {
3232        writeln!(code, "                qkv_bias,")?;
3233    }
3234    writeln!(code, "                o_weight,")?;
3235    writeln!(code, "                ffn_norm,")?;
3236    writeln!(code, "                gate_up_weight,")?;
3237    writeln!(code, "                down_weight,")?;
3238    writeln!(code, "            }});")?;
3239    writeln!(code, "        }}")?;
3240    writeln!(code)?;
3241
3242    // final_norm is always f32
3243    writeln!(
3244        code,
3245        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3246    )?;
3247    writeln!(code)?;
3248
3249    // lm_head
3250    if is_q8 {
3251        writeln!(
3252            code,
3253            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3254        )?;
3255    } else if is_q4 {
3256        writeln!(
3257            code,
3258            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3259        )?;
3260    } else {
3261        writeln!(
3262            code,
3263            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3264        )?;
3265    }
3266    writeln!(code)?;
3267
3268    // Working buffers
3269    let hidden_bytes = hidden * 4;
3270    let _kv_dim_bytes = kv_dim * 4;
3271    let intermediate_bytes = intermediate * 4;
3272    let vocab_bytes = vocab * 4;
3273    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3274
3275    writeln!(
3276        code,
3277        "        // Allocate working buffers (shared memory for zero-copy)"
3278    )?;
3279    writeln!(
3280        code,
3281        "        let opts = MTLResourceOptions::StorageModeShared;"
3282    )?;
3283    writeln!(
3284        code,
3285        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3286    )?;
3287    writeln!(
3288        code,
3289        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3290    )?;
3291    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3292    writeln!(
3293        code,
3294        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3295    )?;
3296    writeln!(
3297        code,
3298        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3299    )?;
3300    writeln!(
3301        code,
3302        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3303    )?;
3304    writeln!(
3305        code,
3306        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3307    )?;
3308    writeln!(
3309        code,
3310        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3311    )?;
3312    let gate_up_buf_bytes = 2 * intermediate * 4;
3313    writeln!(
3314        code,
3315        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3316    )?;
3317    writeln!(
3318        code,
3319        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3320    )?;
3321    writeln!(
3322        code,
3323        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3324    )?;
3325    writeln!(
3326        code,
3327        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3328    )?;
3329    writeln!(
3330        code,
3331        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3332    )?;
3333    writeln!(
3334        code,
3335        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3336    )?;
3337    writeln!(code)?;
3338
3339    // Batch prefill working buffers
3340    let batch_hidden_bytes = hidden * 4; // per-token
3341    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3342    let batch_gate_up_bytes = 2 * intermediate * 4;
3343    let batch_intermediate_bytes = intermediate * 4;
3344    writeln!(
3345        code,
3346        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3347    )?;
3348    writeln!(
3349        code,
3350        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3351    )?;
3352    writeln!(
3353        code,
3354        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3355    )?;
3356    writeln!(
3357        code,
3358        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3359    )?;
3360    writeln!(
3361        code,
3362        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3363    )?;
3364    writeln!(
3365        code,
3366        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3367    )?;
3368    writeln!(
3369        code,
3370        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3371    )?;
3372    writeln!(
3373        code,
3374        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3375    )?;
3376    writeln!(
3377        code,
3378        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3379    )?;
3380    writeln!(
3381        code,
3382        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3383    )?;
3384    writeln!(
3385        code,
3386        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3387    )?;
3388    writeln!(code)?;
3389
3390    // KV cache buffers
3391    writeln!(code, "        // KV cache buffers (per-layer)")?;
3392    writeln!(
3393        code,
3394        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3395    )?;
3396    writeln!(
3397        code,
3398        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3399    )?;
3400    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3401    writeln!(
3402        code,
3403        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3404    )?;
3405    writeln!(
3406        code,
3407        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3408    )?;
3409    writeln!(code, "        }}")?;
3410    writeln!(code)?;
3411
3412    writeln!(code, "        Self {{")?;
3413    writeln!(code, "            device,")?;
3414    writeln!(code, "            queue,")?;
3415    writeln!(code, "            matmul_pipeline,")?;
3416    writeln!(code, "            matmul_q8_pipeline,")?;
3417    writeln!(code, "            matmul_q4_pipeline,")?;
3418    writeln!(code, "            rms_norm_pipeline,")?;
3419    writeln!(code, "            rope_pipeline,")?;
3420    writeln!(code, "            softmax_pipeline,")?;
3421    writeln!(code, "            silu_mul_pipeline,")?;
3422    writeln!(code, "            silu_mul_fused_pipeline,")?;
3423    writeln!(code, "            add_pipeline,")?;
3424    writeln!(code, "            attention_pipeline,")?;
3425    writeln!(code, "            add_inplace_pipeline,")?;
3426    writeln!(code, "            copy_pipeline,")?;
3427    writeln!(code, "            copy_offset_pipeline,")?;
3428    writeln!(code, "            matmul_batch_pipeline,")?;
3429    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3430    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3431    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3432    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3433    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3434    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3435    writeln!(code, "            matmul_q8_mma32_hh4_pipeline,")?;
3436    if config.qkv_bias {
3437        writeln!(code, "            add_bias_batch_pipeline,")?;
3438    }
3439    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3440    writeln!(code, "            rms_norm_batch_pipeline,")?;
3441    writeln!(code, "            rope_batch_pipeline,")?;
3442    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3443    writeln!(code, "            add_inplace_batch_pipeline,")?;
3444    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3445    writeln!(code, "            attention_batch_pipeline,")?;
3446    writeln!(code, "            copy_kv_batch_pipeline,")?;
3447    writeln!(code, "            rope_qk_batch_pipeline,")?;
3448    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3449    writeln!(code, "            embed_buf,")?;
3450    writeln!(code, "            layers,")?;
3451    writeln!(code, "            norm_buf,")?;
3452    writeln!(code, "            lm_head_buf,")?;
3453    writeln!(code, "            hidden_buf,")?;
3454    writeln!(code, "            residual_buf,")?;
3455    writeln!(code, "            normed_buf,")?;
3456    writeln!(code, "            qkv_buf,")?;
3457    writeln!(code, "            attn_out_buf,")?;
3458    writeln!(code, "            attn_proj_buf,")?;
3459    writeln!(code, "            gate_up_buf,")?;
3460    writeln!(code, "            ffn_hidden_buf,")?;
3461    writeln!(code, "            ffn_out_buf,")?;
3462    writeln!(code, "            add_tmp_buf,")?;
3463    writeln!(code, "            logits_buf,")?;
3464    writeln!(code, "            batch_hidden_buf,")?;
3465    writeln!(code, "            batch_residual_buf,")?;
3466    writeln!(code, "            batch_qkv_buf,")?;
3467    writeln!(code, "            batch_attn_out_buf,")?;
3468    writeln!(code, "            batch_attn_proj_buf,")?;
3469    writeln!(code, "            batch_gate_up_buf,")?;
3470    writeln!(code, "            batch_ffn_hidden_buf,")?;
3471    writeln!(code, "            batch_ffn_out_buf,")?;
3472    writeln!(code, "            batch_tokens_buf,")?;
3473    writeln!(code, "            batch_positions_buf,")?;
3474    writeln!(code, "            k_cache,")?;
3475    writeln!(code, "            v_cache,")?;
3476    writeln!(code, "            pos: 0,")?;
3477    writeln!(code, "            prev_cmd: None,")?;
3478    writeln!(code, "        }}")?;
3479    writeln!(code, "    }}")?;
3480    writeln!(code)?;
3481
3482    // ── forward() ──
3483    writeln!(
3484        code,
3485        "    /// Run the forward pass for a single token at the current position."
3486    )?;
3487    writeln!(code, "    ///")?;
3488    writeln!(
3489        code,
3490        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3491    )?;
3492    writeln!(code, "    ///")?;
3493    writeln!(
3494        code,
3495        "    /// All GPU operations are encoded into a single command buffer and"
3496    )?;
3497    writeln!(
3498        code,
3499        "    /// committed once at the end, avoiding per-operation synchronization."
3500    )?;
3501    writeln!(
3502        code,
3503        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3504    )?;
3505    writeln!(
3506        code,
3507        "        // Wait for any pending prefill command buffer"
3508    )?;
3509    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3510    writeln!(code, "            prev.wait_until_completed();")?;
3511    writeln!(code, "        }}")?;
3512    writeln!(code)?;
3513    writeln!(code, "        let pos = self.pos;")?;
3514    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3515    writeln!(code)?;
3516
3517    // Single compute encoder for the entire forward pass — no blit encoder
3518    // transitions. Copy operations use compute copy kernels instead of blits.
3519    let matmul_fn = if is_q8 {
3520        "dispatch_matmul_q8"
3521    } else if is_q4 {
3522        "dispatch_matmul_q4"
3523    } else {
3524        "dispatch_matmul"
3525    };
3526
3527    writeln!(
3528        code,
3529        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3530    )?;
3531    writeln!(code, "        {{")?;
3532    writeln!(
3533        code,
3534        "            let enc = cmd.new_compute_command_encoder();"
3535    )?;
3536    writeln!(code)?;
3537
3538    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3539    writeln!(
3540        code,
3541        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3542    )?;
3543    writeln!(
3544        code,
3545        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3546    )?;
3547    writeln!(
3548        code,
3549        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3550    )?;
3551    writeln!(
3552        code,
3553        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3554        hidden * 4,
3555    )?;
3556    writeln!(code, "            unsafe {{")?;
3557    writeln!(
3558        code,
3559        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
3560    )?;
3561    writeln!(
3562        code,
3563        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3564    )?;
3565    writeln!(
3566        code,
3567        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
3568    )?;
3569    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
3570    writeln!(
3571        code,
3572        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3573    )?;
3574    writeln!(code, "                    hidden_ptr,")?;
3575    writeln!(code, "                    HIDDEN_SIZE,")?;
3576    writeln!(code, "                );")?;
3577    writeln!(
3578        code,
3579        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3580    )?;
3581    writeln!(code, "            }}")?;
3582    writeln!(code)?;
3583
3584    // 2. Transformer layers
3585    writeln!(code, "            // 2. Transformer layers")?;
3586    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3587    writeln!(code)?;
3588    let q_byte_offset = 0usize;
3589    let k_byte_offset = hidden * 4;
3590    let v_byte_offset = (hidden + kv_dim) * 4;
3591
3592    writeln!(
3593        code,
3594        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
3595    )?;
3596    writeln!(
3597        code,
3598        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3599    )?;
3600    writeln!(
3601        code,
3602        "                // Fused Q+K+V matmul: single dispatch for all three projections"
3603    )?;
3604    writeln!(
3605        code,
3606        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3607    )?;
3608    if config.qkv_bias {
3609        writeln!(
3610            code,
3611            "                // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
3612        )?;
3613        writeln!(
3614            code,
3615            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3616        )?;
3617    }
3618    writeln!(
3619        code,
3620        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
3621    )?;
3622    writeln!(
3623        code,
3624        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3625    )?;
3626    writeln!(
3627        code,
3628        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3629    )?;
3630    writeln!(code)?;
3631    writeln!(
3632        code,
3633        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
3634    )?;
3635    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3636    writeln!(
3637        code,
3638        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3639    )?;
3640    writeln!(
3641        code,
3642        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3643    )?;
3644    writeln!(code)?;
3645    writeln!(
3646        code,
3647        "                // Attention using Q from qkv_buf (offset 0)"
3648    )?;
3649    writeln!(
3650        code,
3651        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3652    )?;
3653    writeln!(
3654        code,
3655        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3656    )?;
3657    writeln!(
3658        code,
3659        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3660    )?;
3661    writeln!(
3662        code,
3663        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
3664    )?;
3665    writeln!(
3666        code,
3667        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3668    )?;
3669    writeln!(
3670        code,
3671        "                // Fused gate+up matmul: single dispatch for both projections"
3672    )?;
3673    writeln!(
3674        code,
3675        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3676    )?;
3677    writeln!(
3678        code,
3679        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3680    )?;
3681    writeln!(
3682        code,
3683        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3684    )?;
3685    writeln!(
3686        code,
3687        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3688    )?;
3689    writeln!(code, "            }}")?;
3690    writeln!(code)?;
3691
3692    // 3. Final RMS norm + logits
3693    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
3694    writeln!(
3695        code,
3696        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3697    )?;
3698    writeln!(
3699        code,
3700        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3701    )?;
3702    writeln!(code)?;
3703    writeln!(code, "            enc.end_encoding();")?;
3704    writeln!(code, "        }}")?;
3705    writeln!(code)?;
3706
3707    // 5. Single commit + wait, then read back logits
3708    writeln!(
3709        code,
3710        "        // 5. Commit all GPU work and wait for completion"
3711    )?;
3712    writeln!(code, "        cmd.commit();")?;
3713    writeln!(code, "        cmd.wait_until_completed();")?;
3714    writeln!(code)?;
3715    writeln!(code, "        // 6. Read back logits from GPU")?;
3716    writeln!(code, "        let logits = unsafe {{")?;
3717    writeln!(
3718        code,
3719        "            let ptr = self.logits_buf.contents() as *const f32;"
3720    )?;
3721    writeln!(
3722        code,
3723        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3724    )?;
3725    writeln!(code, "        }};")?;
3726    writeln!(code)?;
3727    writeln!(code, "        self.pos += 1;")?;
3728    writeln!(code, "        logits")?;
3729    writeln!(code, "    }}")?;
3730    writeln!(code)?;
3731
3732    // ── forward_profile: instrumented forward with per-operation timing ──
3733    writeln!(
3734        code,
3735        "    /// Profiling forward pass that prints per-stage GPU timing."
3736    )?;
3737    writeln!(code, "    ///")?;
3738    writeln!(
3739        code,
3740        "    /// Each stage is committed and waited on separately so that GPU timestamps"
3741    )?;
3742    writeln!(
3743        code,
3744        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
3745    )?;
3746    writeln!(
3747        code,
3748        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
3749    )?;
3750    writeln!(
3751        code,
3752        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
3753    )?;
3754    writeln!(code, "        use std::time::Instant;")?;
3755    writeln!(code)?;
3756    writeln!(
3757        code,
3758        "        // Wait for any pending prefill command buffer"
3759    )?;
3760    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3761    writeln!(code, "            prev.wait_until_completed();")?;
3762    writeln!(code, "        }}")?;
3763    writeln!(code)?;
3764    writeln!(code, "        let pos = self.pos;")?;
3765    writeln!(code)?;
3766
3767    // Stage: embedding (CPU, no GPU)
3768    writeln!(
3769        code,
3770        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
3771    )?;
3772    writeln!(code, "        let t_embed = Instant::now();")?;
3773    writeln!(code, "        unsafe {{")?;
3774    writeln!(
3775        code,
3776        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
3777    )?;
3778    writeln!(
3779        code,
3780        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3781    )?;
3782    writeln!(
3783        code,
3784        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
3785    )?;
3786    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
3787    writeln!(
3788        code,
3789        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3790    )?;
3791    writeln!(code, "                hidden_ptr,")?;
3792    writeln!(code, "                HIDDEN_SIZE,")?;
3793    writeln!(code, "            );")?;
3794    writeln!(
3795        code,
3796        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3797    )?;
3798    writeln!(code, "        }}")?;
3799    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
3800    writeln!(code)?;
3801
3802    // Stage: Transformer layers (all together on GPU)
3803    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
3804    writeln!(code, "        let t_layers = Instant::now();")?;
3805    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3806    writeln!(code, "        {{")?;
3807    writeln!(
3808        code,
3809        "            let enc = cmd.new_compute_command_encoder();"
3810    )?;
3811    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3812    writeln!(
3813        code,
3814        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3815    )?;
3816    writeln!(
3817        code,
3818        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3819    )?;
3820    if config.qkv_bias {
3821        writeln!(
3822            code,
3823            "                self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3824        )?;
3825    }
3826    writeln!(
3827        code,
3828        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3829    )?;
3830    writeln!(
3831        code,
3832        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3833    )?;
3834    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3835    writeln!(
3836        code,
3837        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3838    )?;
3839    writeln!(
3840        code,
3841        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3842    )?;
3843    writeln!(
3844        code,
3845        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3846    )?;
3847    writeln!(
3848        code,
3849        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3850    )?;
3851    writeln!(
3852        code,
3853        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3854    )?;
3855    writeln!(
3856        code,
3857        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3858    )?;
3859    writeln!(
3860        code,
3861        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3862    )?;
3863    writeln!(
3864        code,
3865        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3866    )?;
3867    writeln!(
3868        code,
3869        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3870    )?;
3871    writeln!(
3872        code,
3873        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3874    )?;
3875    writeln!(code, "            }}")?;
3876    writeln!(code, "            enc.end_encoding();")?;
3877    writeln!(code, "        }}")?;
3878    writeln!(code, "        cmd.commit();")?;
3879    writeln!(code, "        cmd.wait_until_completed();")?;
3880    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
3881    writeln!(code)?;
3882
3883    // Stage: Final norm + logits
3884    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
3885    writeln!(code, "        let t_logits = Instant::now();")?;
3886    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
3887    writeln!(code, "        {{")?;
3888    writeln!(
3889        code,
3890        "            let enc = cmd2.new_compute_command_encoder();"
3891    )?;
3892    writeln!(
3893        code,
3894        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3895    )?;
3896    writeln!(
3897        code,
3898        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3899    )?;
3900    writeln!(code, "            enc.end_encoding();")?;
3901    writeln!(code, "        }}")?;
3902    writeln!(code, "        cmd2.commit();")?;
3903    writeln!(code, "        cmd2.wait_until_completed();")?;
3904    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
3905    writeln!(code)?;
3906
3907    // Print profile results
3908    writeln!(
3909        code,
3910        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
3911    )?;
3912    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
3913    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
3914    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
3915    writeln!(
3916        code,
3917        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
3918    )?;
3919    writeln!(code)?;
3920
3921    // Read back logits
3922    writeln!(code, "        let logits = unsafe {{")?;
3923    writeln!(
3924        code,
3925        "            let ptr = self.logits_buf.contents() as *const f32;"
3926    )?;
3927    writeln!(
3928        code,
3929        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3930    )?;
3931    writeln!(code, "        }};")?;
3932    writeln!(code)?;
3933    writeln!(code, "        self.pos += 1;")?;
3934    writeln!(code, "        logits")?;
3935    writeln!(code, "    }}")?;
3936    writeln!(code)?;
3937
3938    // ── forward_prefill: single-token async forward (backward compat) ──
3939    writeln!(
3940        code,
3941        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
3942    )?;
3943    writeln!(code, "    ///")?;
3944    writeln!(
3945        code,
3946        "    /// Commits the command buffer without waiting, enabling double-buffered"
3947    )?;
3948    writeln!(
3949        code,
3950        "    /// execution: GPU processes token N while CPU encodes token N+1."
3951    )?;
3952    writeln!(
3953        code,
3954        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
3955    )?;
3956    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
3957    writeln!(code, "    }}")?;
3958    writeln!(code)?;
3959
3960    // ── forward_prefill_batch: batched prefill for multiple tokens ──
3961    // Batched matmuls for QKV/O/FFN projections, plus a single batched causal-attention dispatch for all M tokens.
3962    let batch_matmul_fn = if is_q8 {
3963        "dispatch_matmul_q8_batch"
3964    } else if is_q4 {
3965        "dispatch_matmul_q4_batch"
3966    } else {
3967        "dispatch_matmul_batch"
3968    };
3969
3970    writeln!(
3971        code,
3972        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
3973    )?;
3974    writeln!(code, "    ///")?;
3975    writeln!(
3976        code,
3977        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
3978    )?;
3979    writeln!(
3980        code,
3981        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
3982    )?;
3983    writeln!(
3984        code,
3985        "    /// This provides significant speedup during prompt prefill."
3986    )?;
3987    writeln!(
3988        code,
3989        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
3990    )?;
3991    writeln!(code, "        let m = tokens.len().min(MAX_BATCH_SIZE);")?;
3992    writeln!(code, "        if m == 0 {{ return; }}")?;
3993    writeln!(code, "        let start_pos = self.pos;")?;
3994    writeln!(code)?;
3995    writeln!(code, "        // Wait for any pending command buffer")?;
3996    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3997    writeln!(code, "            prev.wait_until_completed();")?;
3998    writeln!(code, "        }}")?;
3999    writeln!(code)?;
4000
4001    // Upload token IDs and positions to GPU
4002    writeln!(
4003        code,
4004        "        // Upload token IDs and positions to GPU buffers"
4005    )?;
4006    writeln!(code, "        unsafe {{")?;
4007    writeln!(
4008        code,
4009        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4010    )?;
4011    writeln!(
4012        code,
4013        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4014    )?;
4015    writeln!(code, "            for i in 0..m {{")?;
4016    writeln!(code, "                *tok_ptr.add(i) = tokens[i];")?;
4017    writeln!(
4018        code,
4019        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
4020    )?;
4021    writeln!(code, "            }}")?;
4022    writeln!(code, "        }}")?;
4023    writeln!(code)?;
4024
4025    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
4026    writeln!(code, "        {{")?;
4027    writeln!(
4028        code,
4029        "            let enc = cmd.new_compute_command_encoder();"
4030    )?;
4031    writeln!(code)?;
4032
4033    // 1. Batch embedding lookup
4034    writeln!(
4035        code,
4036        "            // 1. Batch embedding lookup: copy all token embeddings at once"
4037    )?;
4038    writeln!(
4039        code,
4040        "            self.dispatch_copy_embedding_batch(&enc, m);"
4041    )?;
4042    // Copy batch_hidden -> batch_residual
4043    writeln!(
4044        code,
4045        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4046    )?;
4047    writeln!(code)?;
4048
4049    // 2. Transformer layers
4050    writeln!(code, "            // 2. Transformer layers")?;
4051    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
4052    writeln!(code)?;
4053
4054    // Batch RMS norm: residual -> hidden (batched)
4055    writeln!(
4056        code,
4057        "                // Batch RMS norm: batch_residual -> batch_hidden"
4058    )?;
4059    writeln!(
4060        code,
4061        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4062    )?;
4063
4064    // Batch QKV matmul
4065    writeln!(
4066        code,
4067        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4068    )?;
4069    writeln!(
4070        code,
4071        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4072    )?;
4073    if config.qkv_bias {
4074        writeln!(
4075            code,
4076            "                // Qwen2: broadcast-add QKV bias across all M tokens."
4077        )?;
4078        writeln!(
4079            code,
4080            "                self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4081        )?;
4082    }
4083    writeln!(code)?;
4084
4085    // Fused RoPE on Q+K portions in a single dispatch
4086    let k_float_offset = hidden;
4087    writeln!(
4088        code,
4089        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4090    )?;
4091    writeln!(
4092        code,
4093        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4094    )?;
4095    writeln!(code)?;
4096
4097    // Fused KV cache update: copy both K and V in a single dispatch
4098    let v_float_offset = hidden + kv_dim;
4099    writeln!(
4100        code,
4101        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4102    )?;
4103    writeln!(
4104        code,
4105        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4106    )?;
4107    writeln!(code)?;
4108
4109    // Batched causal attention: ONE dispatch for all M tokens
4110    writeln!(
4111        code,
4112        "                // Batched causal attention: one dispatch for all M tokens"
4113    )?;
4114    writeln!(
4115        code,
4116        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4117    )?;
4118    writeln!(code)?;
4119
4120    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
4121    writeln!(code, "                // Batched O projection")?;
4122    writeln!(
4123        code,
4124        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4125    )?;
4126    writeln!(code)?;
4127
4128    // Batch add: residual += attn_proj for all tokens
4129    writeln!(
4130        code,
4131        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4132    )?;
4133    writeln!(code)?;
4134
4135    // Batch FFN
4136    writeln!(
4137        code,
4138        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4139    )?;
4140    writeln!(
4141        code,
4142        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4143    )?;
4144    writeln!(
4145        code,
4146        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4147    )?;
4148    writeln!(
4149        code,
4150        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4151    )?;
4152    writeln!(
4153        code,
4154        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4155    )?;
4156    writeln!(
4157        code,
4158        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4159    )?;
4160    writeln!(code, "            }}")?;
4161    writeln!(code)?;
4162
4163    // Copy last token's residual to single-token residual_buf for next forward() call
4164    writeln!(
4165        code,
4166        "            // Copy last token's residual to single-token buffer for subsequent forward()"
4167    )?;
4168    writeln!(
4169        code,
4170        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4171    )?;
4172    writeln!(code)?;
4173    writeln!(code, "            enc.end_encoding();")?;
4174    writeln!(code, "        }}")?;
4175    writeln!(code)?;
4176
4177    writeln!(code, "        cmd.commit();")?;
4178    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
4179    writeln!(code, "        self.pos += m;")?;
4180    writeln!(code, "    }}")?;
4181    writeln!(code)?;
4182
4183    // ── reset() — rewind KV cache position for new inference requests ──
4184    writeln!(
4185        code,
4186        "    /// Reset the model state for a new inference request."
4187    )?;
4188    writeln!(code, "    pub fn reset(&mut self) {{")?;
4189    writeln!(code, "        self.pos = 0;")?;
4190    writeln!(code, "        self.prev_cmd = None;")?;
4191    writeln!(code, "    }}")?;
4192    writeln!(code)?;
4193
4194    // ── Private dispatch helpers (all take a shared compute encoder) ──
4195    writeln!(
4196        code,
4197        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
4198    )?;
4199    writeln!(
4200        code,
4201        "    // These methods set pipeline state + buffers + dispatch on an existing"
4202    )?;
4203    writeln!(
4204        code,
4205        "    // encoder, avoiding per-operation encoder creation overhead."
4206    )?;
4207    writeln!(code)?;
4208
4209    // dispatch_rms_norm
4210    writeln!(
4211        code,
4212        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4213    )?;
4214    writeln!(
4215        code,
4216        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4217    )?;
4218    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4219    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4220    writeln!(
4221        code,
4222        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4223    )?;
4224    writeln!(
4225        code,
4226        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
4227    )?;
4228    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4229    writeln!(
4230        code,
4231        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4232    )?;
4233    writeln!(
4234        code,
4235        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4236    )?;
4237    writeln!(
4238        code,
4239        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4240    )?;
4241    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4242    writeln!(
4243        code,
4244        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4245    )?;
4246    writeln!(
4247        code,
4248        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4249    )?;
4250    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4251    writeln!(code, "    }}")?;
4252    writeln!(code)?;
4253
4254    // dispatch_matmul
4255    writeln!(
4256        code,
4257        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4258    )?;
4259    writeln!(
4260        code,
4261        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4262    )?;
4263    writeln!(code, "        let r: u32 = rows as u32;")?;
4264    writeln!(code, "        let c: u32 = cols as u32;")?;
4265    writeln!(
4266        code,
4267        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4268    )?;
4269    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4270    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4271    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4272    writeln!(
4273        code,
4274        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4275    )?;
4276    writeln!(
4277        code,
4278        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4279    )?;
4280    writeln!(
4281        code,
4282        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4283    )?;
4284    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4285    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4286    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4287    writeln!(
4288        code,
4289        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4290    )?;
4291    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4292    writeln!(code, "    }}")?;
4293    writeln!(code)?;
4294
4295    // dispatch_matmul_q8
4296    writeln!(
4297        code,
4298        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4299    )?;
4300    writeln!(
4301        code,
4302        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4303    )?;
4304    writeln!(
4305        code,
4306        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4307    )?;
4308    writeln!(code, "        let r: u32 = rows as u32;")?;
4309    writeln!(code, "        let c: u32 = cols as u32;")?;
4310    writeln!(
4311        code,
4312        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4313    )?;
4314    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4315    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4316    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4317    writeln!(
4318        code,
4319        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4320    )?;
4321    writeln!(
4322        code,
4323        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4324    )?;
4325    writeln!(
4326        code,
4327        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4328    )?;
4329    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4330    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4331    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4332    writeln!(
4333        code,
4334        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4335    )?;
4336    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4337    writeln!(code, "    }}")?;
4338    writeln!(code)?;
4339
4340    // dispatch_matmul_q4
4341    writeln!(
4342        code,
4343        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4344    )?;
4345    writeln!(
4346        code,
4347        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4348    )?;
4349    writeln!(
4350        code,
4351        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4352    )?;
4353    writeln!(code, "        let r: u32 = rows as u32;")?;
4354    writeln!(code, "        let c: u32 = cols as u32;")?;
4355    writeln!(
4356        code,
4357        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4358    )?;
4359    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4360    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4361    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4362    writeln!(
4363        code,
4364        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4365    )?;
4366    writeln!(
4367        code,
4368        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4369    )?;
4370    writeln!(
4371        code,
4372        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4373    )?;
4374    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4375    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4376    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4377    writeln!(
4378        code,
4379        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4380    )?;
4381    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4382    writeln!(code, "    }}")?;
4383    writeln!(code)?;
4384
4385    // dispatch_rope
4386    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4387    writeln!(
4388        code,
4389        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4390    )?;
4391    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4392    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4393    writeln!(code, "        let p: u32 = pos as u32;")?;
4394    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4395    writeln!(
4396        code,
4397        "        let total_pairs = num_heads * (head_dim / 2);"
4398    )?;
4399    writeln!(
4400        code,
4401        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4402    )?;
4403    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4404    writeln!(
4405        code,
4406        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4407    )?;
4408    writeln!(
4409        code,
4410        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4411    )?;
4412    writeln!(
4413        code,
4414        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4415    )?;
4416    writeln!(
4417        code,
4418        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4419    )?;
4420    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4421    writeln!(
4422        code,
4423        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4424    )?;
4425    writeln!(
4426        code,
4427        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4428    )?;
4429    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4430    writeln!(code, "    }}")?;
4431    writeln!(code)?;
4432
4433    // dispatch_rope_offset
4434    writeln!(
4435        code,
4436        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4437    )?;
4438    writeln!(
4439        code,
4440        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4441    )?;
4442    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4443    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4444    writeln!(code, "        let p: u32 = pos as u32;")?;
4445    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4446    writeln!(
4447        code,
4448        "        let total_pairs = num_heads * (head_dim / 2);"
4449    )?;
4450    writeln!(
4451        code,
4452        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4453    )?;
4454    writeln!(
4455        code,
4456        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4457    )?;
4458    writeln!(
4459        code,
4460        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4461    )?;
4462    writeln!(
4463        code,
4464        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4465    )?;
4466    writeln!(
4467        code,
4468        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4469    )?;
4470    writeln!(
4471        code,
4472        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4473    )?;
4474    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4475    writeln!(
4476        code,
4477        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4478    )?;
4479    writeln!(
4480        code,
4481        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4482    )?;
4483    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4484    writeln!(code, "    }}")?;
4485    writeln!(code)?;
4486
4487    // dispatch_attention
4488    writeln!(
4489        code,
4490        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4491    )?;
4492    writeln!(
4493        code,
4494        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4495    )?;
4496    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4497    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4498    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4499    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4500    writeln!(
4501        code,
4502        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4503    )?;
4504    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4505    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4506    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4507    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4508    writeln!(
4509        code,
4510        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4511    )?;
4512    writeln!(
4513        code,
4514        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4515    )?;
4516    writeln!(
4517        code,
4518        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4519    )?;
4520    writeln!(
4521        code,
4522        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4523    )?;
4524    writeln!(
4525        code,
4526        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4527    )?;
4528    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4529    writeln!(
4530        code,
4531        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4532    )?;
4533    writeln!(
4534        code,
4535        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4536    )?;
4537    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4538    writeln!(code, "    }}")?;
4539    writeln!(code)?;
4540
4541    // dispatch_attention_offset
4542    writeln!(
4543        code,
4544        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4545    )?;
4546    writeln!(
4547        code,
4548        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4549    )?;
4550    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4551    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4552    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4553    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4554    writeln!(
4555        code,
4556        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4557    )?;
4558    writeln!(
4559        code,
4560        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
4561    )?;
4562    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4563    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4564    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4565    writeln!(
4566        code,
4567        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4568    )?;
4569    writeln!(
4570        code,
4571        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4572    )?;
4573    writeln!(
4574        code,
4575        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4576    )?;
4577    writeln!(
4578        code,
4579        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4580    )?;
4581    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4582    writeln!(
4583        code,
4584        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4585    )?;
4586    writeln!(
4587        code,
4588        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4589    )?;
4590    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4591    writeln!(code, "    }}")?;
4592    writeln!(code)?;
4593
4594    // dispatch_silu_mul
4595    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
4596    writeln!(
4597        code,
4598        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
4599    )?;
4600    writeln!(code, "        let count: u32 = n as u32;")?;
4601    writeln!(
4602        code,
4603        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
4604    )?;
4605    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
4606    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
4607    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4608    writeln!(
4609        code,
4610        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4611    )?;
4612    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4613    writeln!(
4614        code,
4615        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4616    )?;
4617    writeln!(
4618        code,
4619        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4620    )?;
4621    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4622    writeln!(code, "    }}")?;
4623    writeln!(code)?;
4624
4625    // dispatch_silu_mul_fused
4626    writeln!(
4627        code,
4628        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
4629    )?;
4630    writeln!(
4631        code,
4632        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
4633    )?;
4634    writeln!(
4635        code,
4636        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
4637    )?;
4638    writeln!(code, "        let count: u32 = n as u32;")?;
4639    writeln!(
4640        code,
4641        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
4642    )?;
4643    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
4644    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
4645    writeln!(
4646        code,
4647        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4648    )?;
4649    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4650    writeln!(
4651        code,
4652        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4653    )?;
4654    writeln!(
4655        code,
4656        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4657    )?;
4658    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4659    writeln!(code, "    }}")?;
4660    writeln!(code)?;
4661
4662    // dispatch_copy (simple src -> dst copy via compute kernel)
4663    writeln!(
4664        code,
4665        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
4666    )?;
4667    writeln!(
4668        code,
4669        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4670    )?;
4671    writeln!(code, "        let n: u32 = count as u32;")?;
4672    writeln!(
4673        code,
4674        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4675    )?;
4676    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4677    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4678    writeln!(
4679        code,
4680        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4681    )?;
4682    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4683    writeln!(
4684        code,
4685        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4686    )?;
4687    writeln!(
4688        code,
4689        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4690    )?;
4691    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4692    writeln!(code, "    }}")?;
4693    writeln!(code)?;
4694
4695    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
4696    writeln!(
4697        code,
4698        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
4699    )?;
4700    writeln!(
4701        code,
4702        "    /// Used for embedding table lookup (copy a specific row)."
4703    )?;
4704    writeln!(
4705        code,
4706        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
4707    )?;
4708    writeln!(code, "        let off: u32 = src_offset as u32;")?;
4709    writeln!(code, "        let n: u32 = count as u32;")?;
4710    writeln!(
4711        code,
4712        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
4713    )?;
4714    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4715    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4716    writeln!(
4717        code,
4718        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
4719    )?;
4720    writeln!(
4721        code,
4722        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4723    )?;
4724    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4725    writeln!(
4726        code,
4727        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4728    )?;
4729    writeln!(
4730        code,
4731        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4732    )?;
4733    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4734    writeln!(code, "    }}")?;
4735    writeln!(code)?;
4736
4737    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
4738    writeln!(
4739        code,
4740        "    /// Dispatch copy from source at byte offset to destination at float offset."
4741    )?;
4742    writeln!(
4743        code,
4744        "    /// Used for KV cache updates from fused QKV buffer."
4745    )?;
4746    writeln!(
4747        code,
4748        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4749    )?;
4750    writeln!(code, "        let n: u32 = count as u32;")?;
4751    writeln!(
4752        code,
4753        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4754    )?;
4755    writeln!(
4756        code,
4757        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4758    )?;
4759    writeln!(
4760        code,
4761        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4762    )?;
4763    writeln!(
4764        code,
4765        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4766    )?;
4767    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4768    writeln!(
4769        code,
4770        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4771    )?;
4772    writeln!(
4773        code,
4774        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4775    )?;
4776    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4777    writeln!(code, "    }}")?;
4778    writeln!(code)?;
4779
4780    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
4781    writeln!(
4782        code,
4783        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
4784    )?;
4785    writeln!(
4786        code,
4787        "    /// Used for KV cache updates (write to a specific position in the cache)."
4788    )?;
4789    writeln!(
4790        code,
4791        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
4792    )?;
4793    writeln!(code, "        let n: u32 = count as u32;")?;
4794    writeln!(
4795        code,
4796        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4797    )?;
4798    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4799    writeln!(
4800        code,
4801        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
4802    )?;
4803    writeln!(
4804        code,
4805        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4806    )?;
4807    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4808    writeln!(
4809        code,
4810        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4811    )?;
4812    writeln!(
4813        code,
4814        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4815    )?;
4816    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4817    writeln!(code, "    }}")?;
4818    writeln!(code)?;
4819
4820    // dispatch_add_inplace (residual connection, no blit needed)
4821    writeln!(
4822        code,
4823        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
4824    )?;
4825    writeln!(
4826        code,
4827        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
4828    )?;
4829    writeln!(code, "        let count: u32 = n as u32;")?;
4830    writeln!(
4831        code,
4832        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
4833    )?;
4834    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
4835    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
4836    writeln!(
4837        code,
4838        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4839    )?;
4840    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4841    writeln!(
4842        code,
4843        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4844    )?;
4845    writeln!(
4846        code,
4847        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4848    )?;
4849    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4850    writeln!(code, "    }}")?;
4851    writeln!(code)?;
4852
4853    // ── Batched prefill dispatch helpers ──
4854    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
4855    writeln!(code)?;
4856
4857    // dispatch_copy_embedding_batch
4858    writeln!(
4859        code,
4860        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
4861    )?;
4862    writeln!(
4863        code,
4864        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
4865    )?;
4866    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
4867    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4868    writeln!(
4869        code,
4870        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
4871    )?;
4872    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
4873    writeln!(
4874        code,
4875        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
4876    )?;
4877    writeln!(
4878        code,
4879        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
4880    )?;
4881    writeln!(
4882        code,
4883        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
4884    )?;
4885    writeln!(
4886        code,
4887        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4888    )?;
4889    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
4890    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4891    writeln!(
4892        code,
4893        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4894    )?;
4895    writeln!(
4896        code,
4897        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4898    )?;
4899    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4900    writeln!(code, "    }}")?;
4901    writeln!(code)?;
4902
4903    // dispatch_rms_norm_batch
4904    writeln!(
4905        code,
4906        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
4907    )?;
4908    writeln!(
4909        code,
4910        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
4911    )?;
4912    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4913    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4914    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4915    writeln!(
4916        code,
4917        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
4918    )?;
4919    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
4920    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4921    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4922    writeln!(
4923        code,
4924        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4925    )?;
4926    writeln!(
4927        code,
4928        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4929    )?;
4930    writeln!(
4931        code,
4932        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4933    )?;
4934    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4935    writeln!(
4936        code,
4937        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
4938    )?;
4939    writeln!(
4940        code,
4941        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4942    )?;
4943    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4944    writeln!(code, "    }}")?;
4945    writeln!(code)?;
4946
4947    // dispatch_matmul_batch (f32)
4948    writeln!(
4949        code,
4950        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4951    )?;
4952    writeln!(
4953        code,
4954        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4955    )?;
4956    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4957    writeln!(code, "        let r: u32 = rows as u32;")?;
4958    writeln!(code, "        let c: u32 = cols as u32;")?;
4959    writeln!(
4960        code,
4961        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
4962    )?;
4963    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4964    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4965    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4966    writeln!(
4967        code,
4968        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4969    )?;
4970    writeln!(
4971        code,
4972        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4973    )?;
4974    writeln!(
4975        code,
4976        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4977    )?;
4978    writeln!(
4979        code,
4980        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
4981    )?;
4982    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
4983    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4984    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4985    writeln!(
4986        code,
4987        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4988    )?;
4989    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4990    writeln!(code, "    }}")?;
4991    writeln!(code)?;
4992
4993    // dispatch_matmul_q8_batch
4994    writeln!(
4995        code,
4996        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4997    )?;
4998    writeln!(code, "    ///")?;
4999    writeln!(
5000        code,
5001        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5002    )?;
5003    writeln!(
5004        code,
5005        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5006    )?;
5007    writeln!(
5008        code,
5009        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5010    )?;
5011    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5012    writeln!(code, "        let r: u32 = rows as u32;")?;
5013    writeln!(code, "        let c: u32 = cols as u32;")?;
5014    writeln!(
5015        code,
5016        "        // Tile sizes must match the Metal shader constants."
5017    )?;
5018    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
5019    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
5020    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
5021    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
5022    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
5023    writeln!(
5024        code,
5025        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
5026    )?;
5027    writeln!(
5028        code,
5029        "        // Prefer the large 32×32 tile when the problem supports it — halves"
5030    )?;
5031    writeln!(
5032        code,
5033        "        // dispatch count and reuses each weight load across 32 tokens."
5034    )?;
5035    writeln!(
5036        code,
5037        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5038    )?;
5039    writeln!(
5040        code,
5041        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5042    )?;
5043    writeln!(
5044        code,
5045        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
5046    )?;
5047    writeln!(
5048        code,
5049        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5050    )?;
5051    writeln!(
5052        code,
5053        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
5054    )?;
5055    writeln!(
5056        code,
5057        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5058    )?;
5059    writeln!(
5060        code,
5061        "            // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5062    )?;
5063    writeln!(
5064        code,
5065        "            // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5066    )?;
5067    writeln!(
5068        code,
5069        "            // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5070    )?;
5071    writeln!(code, "            let use_h4 = cols >= 2048;")?;
5072    writeln!(code, "            let pipe = if use_h4 {{")?;
5073    writeln!(code, "                if num_tokens >= 256 {{")?;
5074    writeln!(code, "                    &self.matmul_q8_mma32_hh4_pipeline")?;
5075    writeln!(code, "                }} else {{")?;
5076    writeln!(code, "                    &self.matmul_q8_mma32_h4_pipeline")?;
5077    writeln!(code, "                }}")?;
5078    writeln!(code, "            }} else {{")?;
5079    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
5080    writeln!(code, "            }};")?;
5081    writeln!(
5082        code,
5083        "            enc.set_compute_pipeline_state(pipe);"
5084    )?;
5085    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5086    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5087    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5088    writeln!(
5089        code,
5090        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5091    )?;
5092    writeln!(
5093        code,
5094        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5095    )?;
5096    writeln!(
5097        code,
5098        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5099    )?;
5100    writeln!(
5101        code,
5102        "            let row_tgs = rows / MMA32_ROW_TILE;"
5103    )?;
5104    writeln!(
5105        code,
5106        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5107    )?;
5108    writeln!(
5109        code,
5110        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5111    )?;
5112    writeln!(
5113        code,
5114        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5115    )?;
5116    writeln!(
5117        code,
5118        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5119    )?;
5120    writeln!(
5121        code,
5122        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5123    )?;
5124    writeln!(
5125        code,
5126        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5127    )?;
5128    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5129    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5130    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5131    writeln!(
5132        code,
5133        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5134    )?;
5135    writeln!(
5136        code,
5137        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5138    )?;
5139    writeln!(
5140        code,
5141        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5142    )?;
5143    writeln!(
5144        code,
5145        "            let row_tgs = rows / MMA_ROW_TILE;"
5146    )?;
5147    writeln!(
5148        code,
5149        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5150    )?;
5151    writeln!(
5152        code,
5153        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
5154    )?;
5155    writeln!(
5156        code,
5157        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5158    )?;
5159    writeln!(
5160        code,
5161        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5162    )?;
5163    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5164    writeln!(
5165        code,
5166        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5167    )?;
5168    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5169    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5170    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5171    writeln!(
5172        code,
5173        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5174    )?;
5175    writeln!(
5176        code,
5177        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5178    )?;
5179    writeln!(
5180        code,
5181        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5182    )?;
5183    writeln!(
5184        code,
5185        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5186    )?;
5187    writeln!(
5188        code,
5189        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5190    )?;
5191    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5192    writeln!(
5193        code,
5194        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5195    )?;
5196    writeln!(
5197        code,
5198        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5199    )?;
5200    writeln!(code, "        }} else {{")?;
5201    writeln!(
5202        code,
5203        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5204    )?;
5205    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
5206    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
5207    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
5208    writeln!(
5209        code,
5210        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5211    )?;
5212    writeln!(
5213        code,
5214        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5215    )?;
5216    writeln!(
5217        code,
5218        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5219    )?;
5220    writeln!(
5221        code,
5222        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
5223    )?;
5224    writeln!(
5225        code,
5226        "            let num_tg = (row_tgs * num_tokens) as u64;"
5227    )?;
5228    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5229    writeln!(
5230        code,
5231        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
5232    )?;
5233    writeln!(
5234        code,
5235        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5236    )?;
5237    writeln!(code, "        }}")?;
5238    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5239    writeln!(code, "    }}")?;
5240    writeln!(code)?;
5241
5242    // dispatch_matmul_q4_batch
5243    writeln!(
5244        code,
5245        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5246    )?;
5247    writeln!(
5248        code,
5249        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5250    )?;
5251    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5252    writeln!(code, "        let r: u32 = rows as u32;")?;
5253    writeln!(code, "        let c: u32 = cols as u32;")?;
5254    writeln!(
5255        code,
5256        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5257    )?;
5258    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5259    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5260    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5261    writeln!(
5262        code,
5263        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5264    )?;
5265    writeln!(
5266        code,
5267        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5268    )?;
5269    writeln!(
5270        code,
5271        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5272    )?;
5273    writeln!(
5274        code,
5275        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5276    )?;
5277    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5278    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5279    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5280    writeln!(
5281        code,
5282        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5283    )?;
5284    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5285    writeln!(code, "    }}")?;
5286    writeln!(code)?;
5287
5288    // dispatch_add_bias_batch — Qwen2 QKV bias broadcast-add after fused qkv matmul.
5289    if config.qkv_bias {
5290        writeln!(
5291            code,
5292            "    /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5293        )?;
5294        writeln!(
5295            code,
5296            "    fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5297        )?;
5298        writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5299        writeln!(code, "        let r: u32 = rows as u32;")?;
5300        writeln!(
5301            code,
5302            "        enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5303        )?;
5304        writeln!(code, "        enc.set_buffer(0, Some(out), 0);")?;
5305        writeln!(code, "        enc.set_buffer(1, Some(bias), 0);")?;
5306        writeln!(
5307            code,
5308            "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5309        )?;
5310        writeln!(
5311            code,
5312            "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5313        )?;
5314        writeln!(code, "        let total = (num_tokens * rows) as u64;")?;
5315        writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5316        writeln!(
5317            code,
5318            "        let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5319        )?;
5320        writeln!(
5321            code,
5322            "        enc.dispatch_thread_groups(grid_size, tg_size);"
5323        )?;
5324        writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5325        writeln!(code, "    }}")?;
5326        writeln!(code)?;
5327    }
5328
5329    // dispatch_rope_batch
5330    writeln!(
5331        code,
5332        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5333    )?;
5334    writeln!(
5335        code,
5336        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5337    )?;
5338    writeln!(
5339        code,
5340        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5341    )?;
5342    writeln!(
5343        code,
5344        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5345    )?;
5346    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5347    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5348    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5349    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5350    writeln!(
5351        code,
5352        "        let pairs_per_token = num_heads * (head_dim / 2);"
5353    )?;
5354    writeln!(
5355        code,
5356        "        let total_pairs = num_tokens * pairs_per_token;"
5357    )?;
5358    // Layout note: the rope_batch kernel expects contiguous [M, num_heads * head_dim]
5359    // data, i.e. it accesses data[token * (num_heads * head_dim) + ...].
5360    // Our batch_qkv_buf is instead [M, qkv_rows], with Q and K located at offsets
5361    // within each token's row, so the values we need live at
5362    // data[token * row_stride + data_float_offset + ...].
5363    // Q and K are contiguous within each token's qkv_rows, so we could bind the
5364    // buffer at byte offset (data_float_offset * 4) and pass a stride parameter,
5365    // but the kernel would have to know row_stride, and rope_batch as written
5366    // assumes the contiguous [M, num_heads*head_dim] layout.
5367    //
5368    // Simplest approach: use the single-token rope kernel for each token in a loop.
5369    // This is still efficient because all dispatches are issued within the same command encoder.
5370    writeln!(
5371        code,
5372        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5373    )?;
5374    writeln!(code, "        for t in 0..num_tokens {{")?;
5375    writeln!(
5376        code,
5377        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5378    )?;
5379    writeln!(
5380        code,
5381        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5382    )?;
5383    writeln!(
5384        code,
5385        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5386    )?;
5387    writeln!(
5388        code,
5389        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5390    )?;
5391    writeln!(
5392        code,
5393        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5394    )?;
5395    writeln!(
5396        code,
5397        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5398    )?;
5399    writeln!(
5400        code,
5401        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5402    )?;
5403    writeln!(
5404        code,
5405        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5406    )?;
5407    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5408    writeln!(
5409        code,
5410        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5411    )?;
5412    writeln!(
5413        code,
5414        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5415    )?;
5416    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5417    writeln!(code, "        }}")?;
5418    writeln!(code, "    }}")?;
5419    writeln!(code)?;
5420
5421    // dispatch_silu_mul_fused_batch
5422    writeln!(
5423        code,
5424        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5425    )?;
5426    writeln!(
5427        code,
5428        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5429    )?;
5430    writeln!(code, "        let count: u32 = n as u32;")?;
5431    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5432    writeln!(
5433        code,
5434        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5435    )?;
5436    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5437    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5438    writeln!(
5439        code,
5440        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5441    )?;
5442    writeln!(
5443        code,
5444        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5445    )?;
5446    writeln!(code, "        let total = num_tokens * n;")?;
5447    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5448    writeln!(
5449        code,
5450        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5451    )?;
5452    writeln!(
5453        code,
5454        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5455    )?;
5456    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5457    writeln!(code, "    }}")?;
5458    writeln!(code)?;
5459
5460    // dispatch_add_inplace_batch_n (add n elements in-place)
5461    writeln!(
5462        code,
5463        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5464    )?;
5465    writeln!(
5466        code,
5467        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5468    )?;
5469    writeln!(code, "        let count: u32 = total_n as u32;")?;
5470    writeln!(
5471        code,
5472        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5473    )?;
5474    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5475    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5476    writeln!(
5477        code,
5478        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5479    )?;
5480    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5481    writeln!(
5482        code,
5483        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5484    )?;
5485    writeln!(
5486        code,
5487        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5488    )?;
5489    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5490    writeln!(code, "    }}")?;
5491    writeln!(code)?;
5492
5493    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5494    writeln!(
5495        code,
5496        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5497    )?;
5498    writeln!(
5499        code,
5500        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5501    )?;
5502    writeln!(code, "        let n: u32 = count as u32;")?;
5503    writeln!(
5504        code,
5505        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5506    )?;
5507    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5508    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5509    writeln!(
5510        code,
5511        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5512    )?;
5513    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5514    writeln!(
5515        code,
5516        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5517    )?;
5518    writeln!(
5519        code,
5520        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5521    )?;
5522    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5523    writeln!(code, "    }}")?;
5524    writeln!(code)?;
5525
5526    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
5527    writeln!(
5528        code,
5529        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5530    )?;
5531    writeln!(
5532        code,
5533        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5534    )?;
5535    writeln!(code, "        let n: u32 = count as u32;")?;
5536    writeln!(
5537        code,
5538        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5539    )?;
5540    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5541    writeln!(
5542        code,
5543        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5544    )?;
5545    writeln!(
5546        code,
5547        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5548    )?;
5549    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5550    writeln!(
5551        code,
5552        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5553    )?;
5554    writeln!(
5555        code,
5556        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5557    )?;
5558    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5559    writeln!(code, "    }}")?;
5560    writeln!(code)?;
5561
5562    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
5563    writeln!(
5564        code,
5565        "    /// Copy from src at byte offset to dst at float offset."
5566    )?;
5567    writeln!(
5568        code,
5569        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5570    )?;
5571    writeln!(code, "        let n: u32 = count as u32;")?;
5572    writeln!(
5573        code,
5574        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5575    )?;
5576    writeln!(
5577        code,
5578        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5579    )?;
5580    writeln!(
5581        code,
5582        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5583    )?;
5584    writeln!(
5585        code,
5586        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5587    )?;
5588    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5589    writeln!(
5590        code,
5591        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5592    )?;
5593    writeln!(
5594        code,
5595        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5596    )?;
5597    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5598    writeln!(code, "    }}")?;
5599    writeln!(code)?;
5600
5601    // dispatch_copy_kv_batch
5602    writeln!(
5603        code,
5604        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
5605    )?;
5606    writeln!(
5607        code,
5608        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
5609    )?;
5610    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5611    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5612    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5613    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5614    writeln!(code, "        let so: u32 = src_offset as u32;")?;
5615    writeln!(
5616        code,
5617        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
5618    )?;
5619    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5620    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5621    writeln!(
5622        code,
5623        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5624    )?;
5625    writeln!(
5626        code,
5627        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5628    )?;
5629    writeln!(
5630        code,
5631        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5632    )?;
5633    writeln!(
5634        code,
5635        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5636    )?;
5637    writeln!(
5638        code,
5639        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
5640    )?;
5641    writeln!(code, "        let total = num_tokens * kv_dim;")?;
5642    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5643    writeln!(
5644        code,
5645        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5646    )?;
5647    writeln!(
5648        code,
5649        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5650    )?;
5651    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5652    writeln!(code, "    }}")?;
5653    writeln!(code)?;
5654
5655    // dispatch_attention_batch
5656    writeln!(
5657        code,
5658        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
5659    )?;
5660    writeln!(
5661        code,
5662        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
5663    )?;
5664    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5665    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5666    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5667    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5668    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5669    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
5670    writeln!(
5671        code,
5672        "        enc.set_compute_pipeline_state(&self.attention_batch_pipeline);"
5673    )?;
5674    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
5675    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5676    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5677    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5678    writeln!(
5679        code,
5680        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5681    )?;
5682    writeln!(
5683        code,
5684        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5685    )?;
5686    writeln!(
5687        code,
5688        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5689    )?;
5690    writeln!(
5691        code,
5692        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5693    )?;
5694    writeln!(
5695        code,
5696        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5697    )?;
5698    writeln!(
5699        code,
5700        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5701    )?;
5702    writeln!(
5703        code,
5704        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
5705    )?;
5706    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5707    writeln!(
5708        code,
5709        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
5710    )?;
5711    writeln!(
5712        code,
5713        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5714    )?;
5715    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5716    writeln!(code, "    }}")?;
5717    writeln!(code)?;
5718
5719    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
5720    writeln!(
5721        code,
5722        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
5723    )?;
5724    writeln!(
5725        code,
5726        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
5727    )?;
5728    writeln!(
5729        code,
5730        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
5731    )?;
5732    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5733    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5734    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
5735    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5736    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5737    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
5738    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5739    writeln!(
5740        code,
5741        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
5742    )?;
5743    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
5744    writeln!(
5745        code,
5746        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5747    )?;
5748    writeln!(
5749        code,
5750        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5751    )?;
5752    writeln!(
5753        code,
5754        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
5755    )?;
5756    writeln!(
5757        code,
5758        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5759    )?;
5760    writeln!(
5761        code,
5762        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5763    )?;
5764    writeln!(
5765        code,
5766        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5767    )?;
5768    writeln!(
5769        code,
5770        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5771    )?;
5772    writeln!(
5773        code,
5774        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
5775    )?;
5776    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5777    writeln!(
5778        code,
5779        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5780    )?;
5781    writeln!(
5782        code,
5783        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5784    )?;
5785    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5786    writeln!(code, "    }}")?;
5787    writeln!(code)?;
5788
5789    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
5790    writeln!(
5791        code,
5792        "    /// Dispatch fused K+V cache copy in one kernel launch."
5793    )?;
5794    writeln!(
5795        code,
5796        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
5797    )?;
5798    writeln!(
5799        code,
5800        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
5801    )?;
5802    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5803    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5804    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5805    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5806    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
5807    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
5808    writeln!(
5809        code,
5810        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
5811    )?;
5812    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5813    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
5814    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
5815    writeln!(
5816        code,
5817        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5818    )?;
5819    writeln!(
5820        code,
5821        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5822    )?;
5823    writeln!(
5824        code,
5825        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5826    )?;
5827    writeln!(
5828        code,
5829        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5830    )?;
5831    writeln!(
5832        code,
5833        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
5834    )?;
5835    writeln!(
5836        code,
5837        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
5838    )?;
5839    writeln!(
5840        code,
5841        "        let total = num_tokens * kv_dim * 2;  // K + V"
5842    )?;
5843    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5844    writeln!(
5845        code,
5846        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5847    )?;
5848    writeln!(
5849        code,
5850        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5851    )?;
5852    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5853    writeln!(code, "    }}")?;
5854
5855    writeln!(code, "}}")?;
5856    writeln!(code)?;
5857
5858    Ok(())
5859}
5860
5861fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
5862    writeln!(
5863        code,
5864        "// ── Helper functions ──────────────────────────────────"
5865    )?;
5866    writeln!(code)?;
5867    writeln!(
5868        code,
5869        "/// Create a compute pipeline from a named function in the Metal library."
5870    )?;
5871    writeln!(
5872        code,
5873        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
5874    )?;
5875    writeln!(
5876        code,
5877        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
5878    )?;
5879    writeln!(
5880        code,
5881        "    device.new_compute_pipeline_state_with_function(&func)"
5882    )?;
5883    writeln!(
5884        code,
5885        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
5886    )?;
5887    writeln!(code, "}}")?;
5888    writeln!(code)?;
5889
5890    Ok(())
5891}
5892
5893// ---------------------------------------------------------------------------
5894// main.rs generation
5895// ---------------------------------------------------------------------------
5896
5897fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
5898    let _sanitized = sanitize_name(model_name);
5899    let _vocab = config.vocab_size;
5900
5901    let mut code = String::with_capacity(16 * 1024);
5902    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
5903    writeln!(
5904        code,
5905        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
5906    )?;
5907    writeln!(code)?;
5908    writeln!(code, "mod model;")?;
5909    writeln!(code)?;
5910    writeln!(code, "use std::io::Write;")?;
5911    writeln!(code, "use std::time::Instant;")?;
5912    writeln!(code, "use serde::Deserialize;")?;
5913    writeln!(code)?;
5914
5915    // -- main function --
5916    writeln!(code, "fn main() {{")?;
5917    writeln!(
5918        code,
5919        "    let args: Vec<String> = std::env::args().collect();"
5920    )?;
5921    writeln!(code)?;
5922    writeln!(
5923        code,
5924        "    // Detect --serve mode (only requires weights + tokenizer)"
5925    )?;
5926    writeln!(
5927        code,
5928        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
5929    )?;
5930    writeln!(code)?;
5931    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
5932    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
5933    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
5934    writeln!(code, "        std::process::exit(1);")?;
5935    writeln!(code, "    }}")?;
5936    writeln!(code)?;
5937    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
5938    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
5939    writeln!(code, "        std::process::exit(1);")?;
5940    writeln!(code, "    }}")?;
5941    writeln!(code)?;
5942    writeln!(code, "    let weights_path = &args[1];")?;
5943    writeln!(code, "    let tokenizer_path = &args[2];")?;
5944    writeln!(code)?;
5945    writeln!(code, "    // Parse optional flags")?;
5946    writeln!(code, "    let mut max_tokens: usize = 128;")?;
5947    writeln!(code, "    let mut port: u16 = 8080;")?;
5948    writeln!(
5949        code,
5950        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
5951    )?;
5952    writeln!(
5953        code,
5954        "    let profile = args.iter().any(|a| a == \"--profile\");"
5955    )?;
5956    writeln!(code, "    let mut i = 3;")?;
5957    writeln!(code, "    while i < args.len() {{")?;
5958    writeln!(
5959        code,
5960        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
5961    )?;
5962    writeln!(
5963        code,
5964        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
5965    )?;
5966    writeln!(code, "            i += 2;")?;
5967    writeln!(
5968        code,
5969        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
5970    )?;
5971    writeln!(
5972        code,
5973        "            port = args[i + 1].parse().unwrap_or(8080);"
5974    )?;
5975    writeln!(code, "            i += 2;")?;
5976    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
5977    writeln!(code, "            i += 1;")?;
5978    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
5979    writeln!(code, "            i += 1;")?;
5980    writeln!(code, "        }} else {{")?;
5981    writeln!(code, "            i += 1;")?;
5982    writeln!(code, "        }}")?;
5983    writeln!(code, "    }}")?;
5984    writeln!(code)?;
5985
5986    // -- load model (shared by both modes) --
5987    writeln!(
5988        code,
5989        "    // Memory-map weights for zero-copy loading on Apple Silicon"
5990    )?;
5991    writeln!(
5992        code,
5993        "    let weights_file = std::fs::File::open(weights_path)"
5994    )?;
5995    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
5996    writeln!(
5997        code,
5998        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
5999    )?;
6000    writeln!(code)?;
6001    writeln!(code, "    // Load tokenizer")?;
6002    writeln!(
6003        code,
6004        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6005    )?;
6006    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6007    writeln!(code)?;
6008    writeln!(code, "    // Create Metal model")?;
6009    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
6010    writeln!(
6011        code,
6012        "    let mut model = model::MetalModel::new(&weights_mmap);"
6013    )?;
6014    writeln!(code)?;
6015
6016    // -- branch: serve vs CLI --
6017    writeln!(code, "    if serve_mode {{")?;
6018    writeln!(code, "        serve(model, tokenizer, port);")?;
6019    writeln!(code, "    }} else {{")?;
6020    writeln!(code, "        let prompt = &args[3];")?;
6021    writeln!(
6022        code,
6023        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6024    )?;
6025    writeln!(code, "    }}")?;
6026    writeln!(code, "}}")?;
6027    writeln!(code)?;
6028
6029    // -- cli_mode function --
6030    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6031    writeln!(code, "    // Tokenize prompt")?;
6032    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
6033    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6034    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6035    writeln!(code)?;
6036    writeln!(
6037        code,
6038        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6039    )?;
6040    writeln!(
6041        code,
6042        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
6043    )?;
6044    writeln!(
6045        code,
6046        "    // The last token uses synchronous forward() to get logits."
6047    )?;
6048    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
6049    writeln!(code, "    let prefill_start = Instant::now();")?;
6050    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
6051    writeln!(
6052        code,
6053        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6054    )?;
6055    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
6056    writeln!(code, "    }} else {{")?;
6057    writeln!(code, "        model.forward(prompt_tokens[0])")?;
6058    writeln!(code, "    }};")?;
6059    writeln!(
6060        code,
6061        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6062    )?;
6063    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
6064    writeln!(
6065        code,
6066        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6067    )?;
6068    writeln!(
6069        code,
6070        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6071    )?;
6072    writeln!(code)?;
6073    writeln!(code, "    // Generate tokens")?;
6074    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6075    writeln!(code, "    let gen_start = Instant::now();")?;
6076    writeln!(code, "    let mut generated_count: usize = 0;")?;
6077    writeln!(code)?;
6078    writeln!(code, "    for _ in 0..max_tokens {{")?;
6079    writeln!(
6080        code,
6081        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6082    )?;
6083    writeln!(code, "            if !quiet {{")?;
6084    writeln!(code, "                print!(\"{{}}\", text);")?;
6085    writeln!(code, "                std::io::stdout().flush().ok();")?;
6086    writeln!(code, "            }}")?;
6087    writeln!(code, "        }}")?;
6088    writeln!(code, "        generated_count += 1;")?;
6089    writeln!(code)?;
6090    writeln!(
6091        code,
6092        "        // Use profiling forward for first token when --profile is set"
6093    )?;
6094    writeln!(
6095        code,
6096        "        let logits = if profile && generated_count == 1 {{"
6097    )?;
6098    writeln!(code, "            model.forward_profile(next_token)")?;
6099    writeln!(code, "        }} else {{")?;
6100    writeln!(code, "            model.forward(next_token)")?;
6101    writeln!(code, "        }};")?;
6102    writeln!(code, "        next_token = argmax(&logits);")?;
6103    writeln!(code)?;
6104    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
6105    writeln!(code, "        if next_token == 2 {{")?;
6106    writeln!(code, "            break;")?;
6107    writeln!(code, "        }}")?;
6108    writeln!(code)?;
6109    writeln!(
6110        code,
6111        "        // Yield between tokens to reduce sustained GPU thermal load."
6112    )?;
6113    writeln!(
6114        code,
6115        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6116    )?;
6117    writeln!(
6118        code,
6119        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6120    )?;
6121    writeln!(
6122        code,
6123        "        // briefly, providing a micro-break that helps sustain peak throughput."
6124    )?;
6125    writeln!(code, "        std::thread::yield_now();")?;
6126    writeln!(code, "    }}")?;
6127    writeln!(code, "    if !quiet {{")?;
6128    writeln!(code, "        println!();")?;
6129    writeln!(code, "    }}")?;
6130    writeln!(
6131        code,
6132        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6133    )?;
6134    writeln!(
6135        code,
6136        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6137    )?;
6138    writeln!(
6139        code,
6140        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6141    )?;
6142    writeln!(code, "}}")?;
6143    writeln!(code)?;
6144
6145    // -- argmax helper --
6146    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6147    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6148    writeln!(code, "    logits.iter()")?;
6149    writeln!(code, "        .enumerate()")?;
6150    writeln!(
6151        code,
6152        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6153    )?;
6154    writeln!(code, "        .map(|(i, _)| i as u32)")?;
6155    writeln!(code, "        .unwrap_or(0)")?;
6156    writeln!(code, "}}")?;
6157    writeln!(code)?;
6158
6159    // -- Request/Response types for OpenAI API --
6160    writeln!(
6161        code,
6162        "// -----------------------------------------------------------------------"
6163    )?;
6164    writeln!(code, "// OpenAI-compatible API server")?;
6165    writeln!(
6166        code,
6167        "// -----------------------------------------------------------------------"
6168    )?;
6169    writeln!(code)?;
6170    writeln!(code, "#[derive(Deserialize)]")?;
6171    writeln!(code, "struct ChatRequest {{")?;
6172    writeln!(code, "    messages: Vec<ChatMessage>,")?;
6173    writeln!(code, "    #[serde(default)]")?;
6174    writeln!(code, "    stream: Option<bool>,")?;
6175    writeln!(code, "    #[serde(default)]")?;
6176    writeln!(code, "    max_tokens: Option<usize>,")?;
6177    writeln!(code, "    #[serde(default)]")?;
6178    writeln!(code, "    temperature: Option<f32>,")?;
6179    writeln!(code, "    #[serde(default)]")?;
6180    writeln!(code, "    model: Option<String>,")?;
6181    writeln!(code, "}}")?;
6182    writeln!(code)?;
6183    writeln!(code, "#[derive(Deserialize)]")?;
6184    writeln!(code, "struct ChatMessage {{")?;
6185    writeln!(code, "    role: String,")?;
6186    writeln!(code, "    content: String,")?;
6187    writeln!(code, "}}")?;
6188    writeln!(code)?;
6189
6190    // -- format_chat_messages --
6191    writeln!(
6192        code,
6193        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6194    )?;
6195    writeln!(code, "    let mut prompt = String::new();")?;
6196    writeln!(code, "    for msg in messages {{")?;
6197    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6198    writeln!(code, "    }}")?;
6199    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6200    writeln!(code, "    prompt")?;
6201    writeln!(code, "}}")?;
6202    writeln!(code)?;
6203
6204    // -- prefill helper --
6205    writeln!(
6206        code,
6207        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6208    )?;
6209    writeln!(code, "    let len = tokens.len();")?;
6210    writeln!(code, "    if len > 1 {{")?;
6211    writeln!(
6212        code,
6213        "        model.forward_prefill_batch(&tokens[..len - 1]);"
6214    )?;
6215    writeln!(code, "    }}")?;
6216    writeln!(code, "    model.forward(tokens[len - 1])")?;
6217    writeln!(code, "}}")?;
6218    writeln!(code)?;
6219
6220    // -- serve function --
6221    writeln!(
6222        code,
6223        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6224    )?;
6225    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6226    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
6227    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6228    writeln!(
6229        code,
6230        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6231    )?;
6232    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
6233    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
6234    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
6235    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
6236    writeln!(code)?;
6237    writeln!(code, "    for request in server.incoming_requests() {{")?;
6238    writeln!(code, "        let method = request.method().to_string();")?;
6239    writeln!(code, "        let url = request.url().to_string();")?;
6240    writeln!(code)?;
6241    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
6242
6243    // -- POST /v1/chat/completions --
6244    writeln!(
6245        code,
6246        "            (\"POST\", \"/v1/chat/completions\") => {{"
6247    )?;
6248    writeln!(
6249        code,
6250        "                handle_chat_completion(&mut model, &tokenizer, request);"
6251    )?;
6252    writeln!(code, "            }}")?;
6253
6254    // -- GET /v1/models --
6255    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
6256    writeln!(code, "                let body = serde_json::json!({{")?;
6257    writeln!(code, "                    \"object\": \"list\",")?;
6258    writeln!(code, "                    \"data\": [{{")?;
6259    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
6260    writeln!(code, "                        \"object\": \"model\",")?;
6261    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
6262    writeln!(code, "                    }}]")?;
6263    writeln!(code, "                }});")?;
6264    writeln!(
6265        code,
6266        "                let resp = tiny_http::Response::from_string(body.to_string())"
6267    )?;
6268    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6269    writeln!(code, "                request.respond(resp).ok();")?;
6270    writeln!(code, "            }}")?;
6271
6272    // -- GET /health --
6273    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
6274    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6275    writeln!(code, "                request.respond(resp).ok();")?;
6276    writeln!(code, "            }}")?;
6277
6278    // -- 404 --
6279    writeln!(code, "            _ => {{")?;
6280    writeln!(
6281        code,
6282        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
6283    )?;
6284    writeln!(code, "                    .with_status_code(404);")?;
6285    writeln!(code, "                request.respond(resp).ok();")?;
6286    writeln!(code, "            }}")?;
6287    writeln!(code, "        }}")?;
6288    writeln!(code, "    }}")?;
6289    writeln!(code, "}}")?;
6290    writeln!(code)?;
6291
6292    // -- handle_chat_completion --
6293    writeln!(code, "fn handle_chat_completion(")?;
6294    writeln!(code, "    model: &mut model::MetalModel,")?;
6295    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6296    writeln!(code, "    mut request: tiny_http::Request,")?;
6297    writeln!(code, ") {{")?;
6298    writeln!(code, "    // Read request body")?;
6299    writeln!(code, "    let mut body = String::new();")?;
6300    writeln!(
6301        code,
6302        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6303    )?;
6304    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6305    writeln!(code, "            .with_status_code(400);")?;
6306    writeln!(code, "        request.respond(resp).ok();")?;
6307    writeln!(code, "        return;")?;
6308    writeln!(code, "    }}")?;
6309    writeln!(code)?;
6310    writeln!(code, "    // Parse JSON")?;
6311    writeln!(
6312        code,
6313        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6314    )?;
6315    writeln!(code, "        Ok(r) => r,")?;
6316    writeln!(code, "        Err(e) => {{")?;
6317    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6318    writeln!(code, "                .with_status_code(400);")?;
6319    writeln!(code, "            request.respond(resp).ok();")?;
6320    writeln!(code, "            return;")?;
6321    writeln!(code, "        }}")?;
6322    writeln!(code, "    }};")?;
6323    writeln!(code)?;
6324    writeln!(
6325        code,
6326        "    let prompt = format_chat_messages(&req.messages);"
6327    )?;
6328    writeln!(
6329        code,
6330        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6331    )?;
6332    writeln!(code, "        Ok(e) => e,")?;
6333    writeln!(code, "        Err(e) => {{")?;
6334    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6335    writeln!(code, "                .with_status_code(500);")?;
6336    writeln!(code, "            request.respond(resp).ok();")?;
6337    writeln!(code, "            return;")?;
6338    writeln!(code, "        }}")?;
6339    writeln!(code, "    }};")?;
6340    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6341    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6342    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6343    writeln!(
6344        code,
6345        "    let _temperature = req.temperature.unwrap_or(1.0);"
6346    )?;
6347    writeln!(code)?;
6348
6349    // -- Reset KV cache for each request --
6350    writeln!(code, "    model.reset();")?;
6351    writeln!(code)?;
6352
6353    // -- Prefill with timing --
6354    writeln!(code, "    let prefill_start = Instant::now();")?;
6355    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6356    writeln!(
6357        code,
6358        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6359    )?;
6360    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6361    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6362    writeln!(code)?;
6363
6364    writeln!(code, "    if stream {{")?;
6365
6366    // -- SSE streaming response --
6367    writeln!(
6368        code,
6369        "        // SSE streaming: generate tokens and build SSE body"
6370    )?;
6371    writeln!(code, "        let gen_start = Instant::now();")?;
6372    writeln!(code, "        let mut generated_count: usize = 0;")?;
6373    writeln!(code, "        let mut sse_body = String::new();")?;
6374    writeln!(code, "        for _ in 0..max_tokens {{")?;
6375    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6376    writeln!(
6377        code,
6378        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6379    )?;
6380    writeln!(
6381        code,
6382        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6383    )?;
6384    writeln!(
6385        code,
6386        "                // escaped includes surrounding quotes, strip them"
6387    )?;
6388    writeln!(
6389        code,
6390        "                let inner = &escaped[1..escaped.len()-1];"
6391    )?;
6392    writeln!(code, "                sse_body.push_str(&format!(")?;
6393    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6394    writeln!(code, "                    inner")?;
6395    writeln!(code, "                ));")?;
6396    writeln!(code, "            }}")?;
6397    writeln!(code, "            generated_count += 1;")?;
6398    writeln!(code, "            let logits = model.forward(next_token);")?;
6399    writeln!(code, "            next_token = argmax(&logits);")?;
6400    writeln!(code, "        }}")?;
6401    writeln!(
6402        code,
6403        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6404    )?;
6405    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6406    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6407    writeln!(code)?;
6408    writeln!(
6409        code,
6410        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6411    )?;
6412    writeln!(code, "        sse_body.push_str(&format!(")?;
6413    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6414    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6415    writeln!(code, "        ));")?;
6416    writeln!(code)?;
6417    writeln!(
6418        code,
6419        "        let resp = tiny_http::Response::from_string(sse_body)"
6420    )?;
6421    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6422    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6423    writeln!(code, "        request.respond(resp).ok();")?;
6424
6425    writeln!(code, "    }} else {{")?;
6426
6427    // -- Non-streaming response --
6428    writeln!(
6429        code,
6430        "        // Non-streaming: generate all tokens, return JSON"
6431    )?;
6432    writeln!(code, "        let gen_start = Instant::now();")?;
6433    writeln!(code, "        let mut generated_count: usize = 0;")?;
6434    writeln!(code, "        let mut generated = String::new();")?;
6435    writeln!(code, "        for _ in 0..max_tokens {{")?;
6436    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6437    writeln!(
6438        code,
6439        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6440    )?;
6441    writeln!(code, "                generated.push_str(&text);")?;
6442    writeln!(code, "            }}")?;
6443    writeln!(code, "            generated_count += 1;")?;
6444    writeln!(code, "            let logits = model.forward(next_token);")?;
6445    writeln!(code, "            next_token = argmax(&logits);")?;
6446    writeln!(code, "        }}")?;
6447    writeln!(
6448        code,
6449        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6450    )?;
6451    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6452    writeln!(code)?;
6453    writeln!(code, "        let resp_json = serde_json::json!({{")?;
6454    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
6455    writeln!(code, "            \"object\": \"chat.completion\",")?;
6456    writeln!(code, "            \"choices\": [{{")?;
6457    writeln!(code, "                \"index\": 0,")?;
6458    writeln!(code, "                \"message\": {{")?;
6459    writeln!(code, "                    \"role\": \"assistant\",")?;
6460    writeln!(code, "                    \"content\": generated")?;
6461    writeln!(code, "                }},")?;
6462    writeln!(code, "                \"finish_reason\": \"stop\"")?;
6463    writeln!(code, "            }}],")?;
6464    writeln!(code, "            \"usage\": {{")?;
6465    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
6466    writeln!(
6467        code,
6468        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6469    )?;
6470    writeln!(
6471        code,
6472        "                \"generation_tokens\": generated_count,"
6473    )?;
6474    writeln!(
6475        code,
6476        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6477    )?;
6478    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
6479    writeln!(code, "            }}")?;
6480    writeln!(code, "        }});")?;
6481    writeln!(
6482        code,
6483        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
6484    )?;
6485    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6486    writeln!(code, "        request.respond(resp).ok();")?;
6487    writeln!(code, "    }}")?;
6488    writeln!(code, "}}")?;
6489
6490    Ok(code)
6491}
6492
6493// ---------------------------------------------------------------------------
6494// Tests
6495// ---------------------------------------------------------------------------
6496
6497#[cfg(test)]
6498mod tests {
6499    use super::*;
6500    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
6501
6502    fn minimal_config() -> ModelConfig {
6503        ModelConfig {
6504            architecture: Architecture::Llama,
6505            hidden_size: 64,
6506            intermediate_size: 128,
6507            num_layers: 2,
6508            num_attention_heads: 4,
6509            num_kv_heads: 4,
6510            head_dim: 16,
6511            vocab_size: 256,
6512            max_seq_len: 512,
6513            rms_norm_eps: 1e-5,
6514            rope_theta: 10000.0,
6515            dtype: DType::F32,
6516            sliding_window_size: None,
6517            qkv_bias: false,
6518        }
6519    }
6520
6521    fn minimal_graph() -> Graph {
6522        Graph::new("test-metal").with_config(minimal_config())
6523    }
6524
6525    #[test]
6526    fn generate_metal_project_creates_files() {
6527        let dir = tempfile::tempdir().unwrap();
6528        let graph = minimal_graph();
6529        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
6530
6531        assert!(
6532            dir.path().join("Cargo.toml").exists(),
6533            "Cargo.toml should be created"
6534        );
6535        assert!(
6536            dir.path().join("src/model.rs").exists(),
6537            "src/model.rs should be created"
6538        );
6539        assert!(
6540            dir.path().join("src/main.rs").exists(),
6541            "src/main.rs should be created"
6542        );
6543        assert!(
6544            dir.path().join("shaders/kernels.metal").exists(),
6545            "shaders/kernels.metal should be created"
6546        );
6547    }
6548
6549    #[test]
6550    fn generated_cargo_toml_has_metal_dep() {
6551        let toml = generate_cargo_toml("my-model");
6552        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
6553        assert!(
6554            toml.contains("tokenizers"),
6555            "Cargo.toml should depend on tokenizers"
6556        );
6557        assert!(
6558            toml.contains("memmap2"),
6559            "Cargo.toml should depend on memmap2"
6560        );
6561        assert!(toml.contains("half"), "Cargo.toml should depend on half");
6562    }
6563
6564    #[test]
6565    fn generated_model_rs_contains_metal_code() {
6566        let config = minimal_config();
6567        let model_rs = generate_model_rs(&config).unwrap();
6568
6569        assert!(
6570            model_rs.contains("pub struct MetalModel"),
6571            "model.rs should define MetalModel struct"
6572        );
6573        assert!(
6574            model_rs.contains("matmul_pipeline: ComputePipelineState"),
6575            "MetalModel should have matmul_pipeline field"
6576        );
6577        assert!(
6578            model_rs.contains("Device::system_default()"),
6579            "model.rs should use Metal device"
6580        );
6581        assert!(
6582            model_rs.contains("new_library_with_source"),
6583            "model.rs should compile Metal shaders"
6584        );
6585        assert!(
6586            model_rs.contains("fn new(weights: &[u8])"),
6587            "MetalModel should implement new()"
6588        );
6589        assert!(
6590            model_rs.contains("fn forward(&mut self, token_id: u32)"),
6591            "MetalModel should implement forward()"
6592        );
6593    }
6594
6595    #[test]
6596    fn generated_shaders_contain_kernel_names() {
6597        let shaders = generate_metal_shaders(&minimal_config());
6598
6599        assert!(
6600            shaders.contains("kernel void matmul_vec"),
6601            "shaders should contain matmul_vec kernel"
6602        );
6603        assert!(
6604            shaders.contains("kernel void rms_norm"),
6605            "shaders should contain rms_norm kernel"
6606        );
6607        assert!(
6608            shaders.contains("kernel void rope"),
6609            "shaders should contain rope kernel"
6610        );
6611        assert!(
6612            shaders.contains("kernel void softmax"),
6613            "shaders should contain softmax kernel"
6614        );
6615        assert!(
6616            shaders.contains("kernel void silu_mul("),
6617            "shaders should contain silu_mul kernel"
6618        );
6619        assert!(
6620            shaders.contains("kernel void silu_mul_fused"),
6621            "shaders should contain silu_mul_fused kernel"
6622        );
6623        assert!(
6624            shaders.contains("kernel void elementwise_add"),
6625            "shaders should contain elementwise_add kernel"
6626        );
6627        assert!(
6628            shaders.contains("kernel void attention"),
6629            "shaders should contain attention kernel"
6630        );
6631        assert!(
6632            shaders.contains("kernel void add_inplace"),
6633            "shaders should contain add_inplace kernel"
6634        );
6635        assert!(
6636            shaders.contains("kernel void copy_buffer"),
6637            "shaders should contain copy_buffer kernel"
6638        );
6639        assert!(
6640            shaders.contains("kernel void copy_offset"),
6641            "shaders should contain copy_offset kernel"
6642        );
6643    }
6644
6645    #[test]
6646    fn generated_shaders_use_simdgroup_features() {
6647        let shaders = generate_metal_shaders(&minimal_config());
6648
6649        assert!(
6650            shaders.contains("threadgroup_barrier"),
6651            "shaders should use threadgroup barriers"
6652        );
6653        assert!(
6654            shaders.contains("threadgroup float"),
6655            "shaders should use threadgroup shared memory"
6656        );
6657        assert!(
6658            shaders.contains("thread_index_in_threadgroup"),
6659            "shaders should use threadgroup indexing"
6660        );
6661        assert!(
6662            shaders.contains("simd_sum"),
6663            "shaders should use simd_sum for warp-level reduction"
6664        );
6665        assert!(
6666            shaders.contains("simd_max"),
6667            "attention kernel should use simd_max for cooperative softmax"
6668        );
6669        assert!(
6670            shaders.contains("thread_index_in_simdgroup"),
6671            "shaders should use simdgroup lane indexing"
6672        );
6673        assert!(
6674            shaders.contains("simdgroup_index_in_threadgroup"),
6675            "shaders should use simdgroup indexing within threadgroup"
6676        );
6677        assert!(
6678            shaders.contains("float4"),
6679            "matmul_vec should use float4 vectorized loads"
6680        );
6681    }
6682
6683    #[test]
6684    fn generated_main_rs_has_tokenizer_usage() {
6685        let config = minimal_config();
6686        let main_rs = generate_main_rs("test-model", &config).unwrap();
6687
6688        assert!(
6689            main_rs.contains("tokenizers::Tokenizer"),
6690            "main.rs should use tokenizers crate"
6691        );
6692        assert!(
6693            main_rs.contains("MetalModel::new"),
6694            "main.rs should call MetalModel::new"
6695        );
6696        assert!(
6697            main_rs.contains("model.forward"),
6698            "main.rs should call model.forward"
6699        );
6700        assert!(
6701            main_rs.contains("memmap2"),
6702            "main.rs should use memmap2 for zero-copy weight loading"
6703        );
6704    }
6705
6706    #[test]
6707    fn missing_config_returns_error() {
6708        let dir = tempfile::tempdir().unwrap();
6709        let graph = Graph::new("no-config");
6710        let result = generate_metal_project(&graph, dir.path(), "fail");
6711        assert!(
6712            matches!(result, Err(MetalCodegenError::MissingConfig)),
6713            "should fail with MissingConfig when graph has no config"
6714        );
6715    }
6716
6717    #[test]
6718    fn sanitize_name_works() {
6719        assert_eq!(sanitize_name("My Model!"), "my-model");
6720        assert_eq!(sanitize_name("test_model"), "test-model");
6721        assert_eq!(sanitize_name("simple"), "simple");
6722    }
6723
6724    #[test]
6725    fn generated_forward_uses_single_command_buffer() {
6726        let config = minimal_config();
6727        let model_rs = generate_model_rs(&config).unwrap();
6728
6729        // The forward function should create exactly one command buffer.
6730        // Use the exact signature to avoid matching forward_prefill/forward_profile.
6731        let forward_start = model_rs
6732            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6733            .unwrap();
6734        let forward_body = &model_rs[forward_start..];
6735        // End at the next pub/private method
6736        let forward_end = forward_body
6737            .find("\n    pub fn forward_profile")
6738            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6739            .or_else(|| forward_body.find("\n    fn dispatch_"))
6740            .unwrap_or(forward_body.len());
6741        let forward_code = &forward_body[..forward_end];
6742
6743        // Should have exactly one new_command_buffer call
6744        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
6745        assert_eq!(
6746            cmd_buf_count, 1,
6747            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
6748        );
6749
6750        // Should have exactly one commit call
6751        let commit_count = forward_code.matches("cmd.commit()").count();
6752        assert_eq!(
6753            commit_count, 1,
6754            "forward() should commit exactly once, found {commit_count}"
6755        );
6756
6757        // Should wait: once for cmd + possibly once for prev_cmd drain
6758        let wait_count = forward_code.matches("wait_until_completed()").count();
6759        assert!(
6760            wait_count >= 1 && wait_count <= 2,
6761            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
6762        );
6763    }
6764
6765    #[test]
6766    fn generated_model_has_preallocated_working_buffers() {
6767        let config = minimal_config();
6768        let model_rs = generate_model_rs(&config).unwrap();
6769
6770        for buf_name in &[
6771            "normed_buf",
6772            "qkv_buf",
6773            "attn_out_buf",
6774            "attn_proj_buf",
6775            "gate_up_buf",
6776            "ffn_hidden_buf",
6777            "ffn_out_buf",
6778            "add_tmp_buf",
6779        ] {
6780            assert!(
6781                model_rs.contains(&format!("{buf_name}: Buffer")),
6782                "MetalModel should have pre-allocated {buf_name} field"
6783            );
6784        }
6785    }
6786
6787    #[test]
6788    fn generated_dispatch_helpers_take_compute_encoder_ref() {
6789        let config = minimal_config();
6790        let model_rs = generate_model_rs(&config).unwrap();
6791
6792        for method in &[
6793            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
6794            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
6795            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
6796            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
6797            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
6798            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
6799            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
6800            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
6801            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
6802            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
6803            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
6804            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
6805            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
6806        ] {
6807            assert!(
6808                model_rs.contains(method),
6809                "model.rs should contain dispatch helper: {method}"
6810            );
6811        }
6812    }
6813
6814    #[test]
6815    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
6816        let config = minimal_config();
6817        let model_rs = generate_model_rs(&config).unwrap();
6818
6819        // Find dispatch helpers section and check none create their own encoders
6820        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
6821        let helpers_code = &model_rs[helpers_start..];
6822
6823        // None of the dispatch_ helpers should call new_command_buffer
6824        assert!(
6825            !helpers_code.contains("self.queue.new_command_buffer()"),
6826            "dispatch helpers should not create their own command buffers"
6827        );
6828
6829        // None should create their own compute encoders
6830        assert!(
6831            !helpers_code.contains("new_compute_command_encoder()"),
6832            "dispatch helpers should not create their own compute encoders"
6833        );
6834
6835        // None should call end_encoding
6836        assert!(
6837            !helpers_code.contains("end_encoding()"),
6838            "dispatch helpers should not call end_encoding"
6839        );
6840
6841        // None should call commit or wait
6842        assert!(
6843            !helpers_code.contains(".commit()"),
6844            "dispatch helpers should not commit command buffers"
6845        );
6846        assert!(
6847            !helpers_code.contains("wait_until_completed"),
6848            "dispatch helpers should not wait on command buffers"
6849        );
6850    }
6851
6852    #[test]
6853    fn generated_forward_batches_compute_encoders() {
6854        let config = minimal_config();
6855        let model_rs = generate_model_rs(&config).unwrap();
6856
6857        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
6858        let forward_start = model_rs
6859            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6860            .unwrap();
6861        let forward_body = &model_rs[forward_start..];
6862        let forward_end = forward_body
6863            .find("\n    pub fn forward_profile")
6864            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6865            .or_else(|| forward_body.find("\n    fn dispatch_"))
6866            .unwrap_or(forward_body.len());
6867        let forward_code = &forward_body[..forward_end];
6868
6869        // Forward should not allocate new buffers
6870        assert!(
6871            !forward_code.contains("device.new_buffer"),
6872            "forward() should not allocate new buffers per call"
6873        );
6874
6875        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
6876        // Copy operations use compute copy kernels instead of blit encoders.
6877        let compute_encoder_count = forward_code
6878            .matches("new_compute_command_encoder()")
6879            .count();
6880        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
6881
6882        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
6883        assert_eq!(
6884            compute_encoder_count, 1,
6885            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
6886        );
6887        assert_eq!(
6888            blit_encoder_count, 0,
6889            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
6890        );
6891    }
6892
6893    #[test]
6894    fn generated_forward_uses_add_inplace() {
6895        let config = minimal_config();
6896        let model_rs = generate_model_rs(&config).unwrap();
6897
6898        // Should use in-place add (no blit copy-back needed)
6899        assert!(
6900            model_rs.contains("dispatch_add_inplace"),
6901            "forward() should use dispatch_add_inplace for residual connections"
6902        );
6903        assert!(
6904            model_rs.contains("add_inplace_pipeline"),
6905            "MetalModel should have add_inplace_pipeline"
6906        );
6907    }
6908
6909    fn minimal_q8_config() -> ModelConfig {
6910        ModelConfig {
6911            architecture: Architecture::Llama,
6912            hidden_size: 64,
6913            intermediate_size: 128,
6914            num_layers: 2,
6915            num_attention_heads: 4,
6916            num_kv_heads: 4,
6917            head_dim: 16,
6918            vocab_size: 256,
6919            max_seq_len: 512,
6920            rms_norm_eps: 1e-5,
6921            rope_theta: 10000.0,
6922            dtype: DType::Q8_0,
6923            sliding_window_size: None,
6924            qkv_bias: false,
6925        }
6926    }
6927
6928    #[test]
6929    fn generated_shaders_contain_q8_kernel() {
6930        let shaders = generate_metal_shaders(&minimal_config());
6931
6932        assert!(
6933            shaders.contains("kernel void matmul_vec_q8"),
6934            "shaders should contain matmul_vec_q8 kernel"
6935        );
6936        assert!(
6937            shaders.contains("device const uchar* matrix"),
6938            "matmul_vec_q8 should accept raw Q8_0 bytes"
6939        );
6940        assert!(
6941            shaders.contains("packed_short4"),
6942            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
6943        );
6944        assert!(
6945            shaders.contains("as_type<char2>"),
6946            "matmul_vec_q8 should bitcast short lanes to char2"
6947        );
6948        assert!(
6949            shaders.contains("device const half*"),
6950            "matmul_vec_q8 should read f16 scale via half pointer"
6951        );
6952    }
6953
6954    #[test]
6955    fn generated_model_uses_fused_qkv_projections() {
6956        let config = minimal_config();
6957        let model_rs = generate_model_rs(&config).unwrap();
6958
6959        // Should have fused QKV weight in layer buffers
6960        assert!(
6961            model_rs.contains("qkv_weight: Buffer"),
6962            "LayerBuffers should have fused qkv_weight field"
6963        );
6964        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
6965        assert!(
6966            !model_rs.contains("    q_weight: Buffer"),
6967            "LayerBuffers should not have separate q_weight field"
6968        );
6969        assert!(
6970            !model_rs.contains("    k_weight: Buffer"),
6971            "LayerBuffers should not have separate k_weight field"
6972        );
6973        assert!(
6974            !model_rs.contains("    v_weight: Buffer"),
6975            "LayerBuffers should not have separate v_weight field"
6976        );
6977
6978        // Should have fused gate_up_weight
6979        assert!(
6980            model_rs.contains("gate_up_weight: Buffer"),
6981            "LayerBuffers should have fused gate_up_weight field"
6982        );
6983        // Should NOT have separate gate/up weight fields
6984        assert!(
6985            !model_rs.contains("    gate_weight: Buffer"),
6986            "LayerBuffers should not have separate gate_weight field"
6987        );
6988        assert!(
6989            !model_rs.contains("    up_weight: Buffer"),
6990            "LayerBuffers should not have separate up_weight field"
6991        );
6992
6993        // Should have fused working buffers
6994        assert!(
6995            model_rs.contains("qkv_buf: Buffer"),
6996            "MetalModel should have fused qkv_buf"
6997        );
6998        assert!(
6999            model_rs.contains("gate_up_buf: Buffer"),
7000            "MetalModel should have fused gate_up_buf"
7001        );
7002
7003        // Forward pass should use fused dispatch
7004        assert!(
7005            model_rs.contains("dispatch_silu_mul_fused"),
7006            "forward pass should use dispatch_silu_mul_fused"
7007        );
7008        assert!(
7009            model_rs.contains("dispatch_rope_offset"),
7010            "forward pass should use dispatch_rope_offset for fused QKV"
7011        );
7012        assert!(
7013            model_rs.contains("dispatch_attention_offset"),
7014            "forward pass should use dispatch_attention_offset for fused QKV"
7015        );
7016    }
7017
7018    #[test]
7019    fn q8_model_has_matmul_q8_pipeline() {
7020        let config = minimal_q8_config();
7021        let model_rs = generate_model_rs(&config).unwrap();
7022
7023        assert!(
7024            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7025            "MetalModel should have matmul_q8_pipeline field"
7026        );
7027        assert!(
7028            model_rs.contains("matmul_q8_pipeline,"),
7029            "MetalModel Self should include matmul_q8_pipeline"
7030        );
7031    }
7032
7033    #[test]
7034    fn q8_model_uses_dispatch_matmul_q8() {
7035        let config = minimal_q8_config();
7036        let model_rs = generate_model_rs(&config).unwrap();
7037
7038        assert!(
7039            model_rs.contains("dispatch_matmul_q8"),
7040            "Q8_0 model should use dispatch_matmul_q8 for projections"
7041        );
7042        assert!(
7043            model_rs.contains("fn dispatch_matmul_q8"),
7044            "model.rs should define dispatch_matmul_q8 method"
7045        );
7046    }
7047
7048    #[test]
7049    fn q8_model_loads_raw_bytes_not_dequantized() {
7050        let config = minimal_q8_config();
7051        let model_rs = generate_model_rs(&config).unwrap();
7052
7053        // Should NOT contain dequantization code
7054        assert!(
7055            !model_rs.contains("f16_to_f32"),
7056            "Q8_0 model should not dequantize weights to f32"
7057        );
7058        assert!(
7059            !model_rs.contains("f32_data"),
7060            "Q8_0 model should not create f32 weight data"
7061        );
7062
7063        // Should load raw Q8_0 bytes directly
7064        assert!(
7065            model_rs.contains("total_raw as u64"),
7066            "Q8_0 model should load raw bytes into Metal buffer"
7067        );
7068    }
7069
7070    #[test]
7071    fn q8_model_norms_stay_f32() {
7072        let config = minimal_q8_config();
7073        let model_rs = generate_model_rs(&config).unwrap();
7074
7075        // Norm weights should still use f32 buffers
7076        assert!(
7077            model_rs.contains("let attn_norm = next_f32_buffer"),
7078            "attn_norm should use f32 buffer even for Q8_0 models"
7079        );
7080        assert!(
7081            model_rs.contains("let ffn_norm = next_f32_buffer"),
7082            "ffn_norm should use f32 buffer even for Q8_0 models"
7083        );
7084        assert!(
7085            model_rs.contains("let norm_buf = next_f32_buffer"),
7086            "final norm should use f32 buffer even for Q8_0 models"
7087        );
7088    }
7089
7090    #[test]
7091    fn q8_model_uses_fused_weight_loading() {
7092        let config = minimal_q8_config();
7093        let model_rs = generate_model_rs(&config).unwrap();
7094
7095        // Should use fused Q8 buffer loading for QKV
7096        assert!(
7097            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7098            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7099        );
7100        // Should use fused Q8 buffer loading for gate+up
7101        assert!(
7102            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7103            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7104        );
7105        // Should still use regular q8 buffer for individual weights
7106        assert!(
7107            model_rs.contains("let o_weight = next_q8_buffer"),
7108            "Q8_0 model should use next_q8_buffer for O weight"
7109        );
7110        assert!(
7111            model_rs.contains("let down_weight = next_q8_buffer"),
7112            "Q8_0 model should use next_q8_buffer for down weight"
7113        );
7114    }
7115
7116    #[test]
7117    fn f32_model_does_not_use_q8_dispatch() {
7118        let config = minimal_config();
7119        let model_rs = generate_model_rs(&config).unwrap();
7120
7121        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
7122        let forward_start = model_rs
7123            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7124            .unwrap();
7125        let forward_body = &model_rs[forward_start..];
7126        let forward_end = forward_body
7127            .find("\n    fn dispatch_")
7128            .unwrap_or(forward_body.len());
7129        let forward_code = &forward_body[..forward_end];
7130
7131        assert!(
7132            !forward_code.contains("dispatch_matmul_q8"),
7133            "f32 model forward should not use dispatch_matmul_q8"
7134        );
7135    }
7136
7137    #[test]
7138    fn q8_dispatch_helper_takes_compute_encoder_ref() {
7139        let config = minimal_q8_config();
7140        let model_rs = generate_model_rs(&config).unwrap();
7141
7142        assert!(
7143            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7144            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7145        );
7146    }
7147
7148    #[test]
7149    fn generated_model_has_double_buffered_prefill() {
7150        let config = minimal_config();
7151        let model_rs = generate_model_rs(&config).unwrap();
7152
7153        // MetalModel should have prev_cmd field for double-buffered prefill
7154        assert!(
7155            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7156            "MetalModel should have prev_cmd field for double-buffered prefill"
7157        );
7158
7159        // Should have forward_prefill method
7160        assert!(
7161            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7162            "MetalModel should have forward_prefill method"
7163        );
7164
7165        // forward() should drain prev_cmd at the start
7166        assert!(
7167            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7168            "forward() should drain prev_cmd from previous prefill"
7169        );
7170    }
7171
7172    #[test]
7173    fn generated_main_rs_uses_forward_prefill_for_prompt() {
7174        let config = minimal_config();
7175        let main_rs = generate_main_rs("test-model", &config).unwrap();
7176
7177        assert!(
7178            main_rs.contains("forward_prefill"),
7179            "main.rs should use forward_prefill for intermediate prompt tokens"
7180        );
7181        assert!(
7182            main_rs.contains("double-buffered"),
7183            "main.rs should document double-buffered prefill"
7184        );
7185    }
7186
7187    #[test]
7188    fn generated_shaders_q8_uses_wide_vectorized_loads() {
7189        let shaders = generate_metal_shaders(&minimal_config());
7190
7191        assert!(
7192            shaders.contains("packed_short4"),
7193            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7194        );
7195        assert!(
7196            shaders.contains("d0[0]"),
7197            "matmul_vec_q8 should index the wide pointer for row 0"
7198        );
7199        assert!(
7200            shaders.contains("as_type<char2>"),
7201            "matmul_vec_q8 should bitcast short lanes to char2"
7202        );
7203        assert!(
7204            shaders.contains("dot("),
7205            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7206        );
7207    }
7208
7209    // ── Q4_0 tests ──────────────────────────────────────────────────────
7210
7211    fn minimal_q4_config() -> ModelConfig {
7212        ModelConfig {
7213            architecture: Architecture::Llama,
7214            hidden_size: 64,
7215            intermediate_size: 128,
7216            num_layers: 2,
7217            num_attention_heads: 4,
7218            num_kv_heads: 4,
7219            head_dim: 16,
7220            vocab_size: 256,
7221            max_seq_len: 512,
7222            rms_norm_eps: 1e-5,
7223            rope_theta: 10000.0,
7224            dtype: DType::Q4_0,
7225            sliding_window_size: None,
7226            qkv_bias: false,
7227        }
7228    }
7229
7230    #[test]
7231    fn generated_shaders_contain_q4_kernel() {
7232        let shaders = generate_metal_shaders(&minimal_config());
7233
7234        assert!(
7235            shaders.contains("kernel void matmul_vec_q4"),
7236            "shaders should contain matmul_vec_q4 kernel"
7237        );
7238        assert!(
7239            shaders.contains("Q4_ROWS_PER_TG"),
7240            "shaders should define Q4_ROWS_PER_TG constant"
7241        );
7242        assert!(
7243            shaders.contains("Q4_ROWS_PER_SG"),
7244            "shaders should define Q4_ROWS_PER_SG constant"
7245        );
7246    }
7247
7248    #[test]
7249    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7250        let shaders = generate_metal_shaders(&minimal_config());
7251
7252        // Q4_0 kernel should use uchar4 for packed byte loads
7253        assert!(
7254            shaders.contains("uchar4"),
7255            "matmul_vec_q4 should use uchar4 for packed byte loads"
7256        );
7257        // Should unpack nibbles with &0xF and >>4
7258        assert!(
7259            shaders.contains("&0xF"),
7260            "matmul_vec_q4 should extract low nibble with &0xF"
7261        );
7262        assert!(
7263            shaders.contains(">>4"),
7264            "matmul_vec_q4 should extract high nibble with >>4"
7265        );
7266        // Should subtract 8 to convert unsigned to signed
7267        assert!(
7268            shaders.contains("-8)"),
7269            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7270        );
7271        // Should use 18-byte block size
7272        assert!(
7273            shaders.contains("blk * 18"),
7274            "matmul_vec_q4 should use 18-byte block stride"
7275        );
7276    }
7277
7278    #[test]
7279    fn q4_model_has_matmul_q4_pipeline() {
7280        let config = minimal_q4_config();
7281        let model_rs = generate_model_rs(&config).unwrap();
7282
7283        assert!(
7284            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7285            "MetalModel should have matmul_q4_pipeline field"
7286        );
7287        assert!(
7288            model_rs.contains("matmul_q4_pipeline,"),
7289            "MetalModel Self should include matmul_q4_pipeline"
7290        );
7291    }
7292
7293    #[test]
7294    fn q4_model_uses_dispatch_matmul_q4() {
7295        let config = minimal_q4_config();
7296        let model_rs = generate_model_rs(&config).unwrap();
7297
7298        assert!(
7299            model_rs.contains("dispatch_matmul_q4"),
7300            "Q4_0 model should use dispatch_matmul_q4 for projections"
7301        );
7302        assert!(
7303            model_rs.contains("fn dispatch_matmul_q4"),
7304            "model.rs should define dispatch_matmul_q4 method"
7305        );
7306    }
7307
7308    #[test]
7309    fn q4_model_loads_raw_bytes_not_dequantized() {
7310        let config = minimal_q4_config();
7311        let model_rs = generate_model_rs(&config).unwrap();
7312
7313        // Should NOT contain dequantization code
7314        assert!(
7315            !model_rs.contains("f16_to_f32"),
7316            "Q4_0 model should not dequantize weights to f32"
7317        );
7318
7319        // Should load raw Q4_0 bytes directly
7320        assert!(
7321            model_rs.contains("total_raw as u64"),
7322            "Q4_0 model should load raw bytes into Metal buffer"
7323        );
7324    }
7325
7326    #[test]
7327    fn q4_model_norms_stay_f32() {
7328        let config = minimal_q4_config();
7329        let model_rs = generate_model_rs(&config).unwrap();
7330
7331        assert!(
7332            model_rs.contains("let attn_norm = next_f32_buffer"),
7333            "attn_norm should use f32 buffer even for Q4_0 models"
7334        );
7335        assert!(
7336            model_rs.contains("let ffn_norm = next_f32_buffer"),
7337            "ffn_norm should use f32 buffer even for Q4_0 models"
7338        );
7339        assert!(
7340            model_rs.contains("let norm_buf = next_f32_buffer"),
7341            "final norm should use f32 buffer even for Q4_0 models"
7342        );
7343    }
7344
7345    #[test]
7346    fn q4_model_uses_fused_weight_loading() {
7347        let config = minimal_q4_config();
7348        let model_rs = generate_model_rs(&config).unwrap();
7349
7350        assert!(
7351            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7352            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7353        );
7354        assert!(
7355            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7356            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7357        );
7358        assert!(
7359            model_rs.contains("let o_weight = next_q4_buffer"),
7360            "Q4_0 model should use next_q4_buffer for O weight"
7361        );
7362        assert!(
7363            model_rs.contains("let down_weight = next_q4_buffer"),
7364            "Q4_0 model should use next_q4_buffer for down weight"
7365        );
7366    }
7367
7368    #[test]
7369    fn qwen2_qkv_bias_wired_through_metal_codegen() {
7370        // Issue #210: the pre-v0.6.2 Metal codegen had zero handling for
7371        // qkv_bias.  Verify that a Qwen2-style config emits the bias buffer,
7372        // loader, pipeline, and dispatch call in the expected places.
7373        let config = ModelConfig {
7374            architecture: Architecture::Qwen2,
7375            qkv_bias: true,
7376            ..minimal_config()
7377        };
7378        let model_rs = generate_model_rs(&config).unwrap();
7379
7380        assert!(
7381            model_rs.contains("qkv_bias: Buffer"),
7382            "Qwen2 LayerBuffers must declare qkv_bias field"
7383        );
7384        assert!(
7385            model_rs.contains("let qkv_bias = next_f32_buffer"),
7386            "Qwen2 layer init must load the bias from the weight blob"
7387        );
7388        assert!(
7389            model_rs.contains("add_bias_batch_pipeline"),
7390            "Qwen2 model struct must include the add_bias_batch_pipeline"
7391        );
7392        assert!(
7393            model_rs.contains("fn dispatch_add_bias_batch"),
7394            "Qwen2 codegen must emit dispatch_add_bias_batch helper"
7395        );
7396        assert!(
7397            model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
7398            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
7399        );
7400        assert!(
7401            model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
7402            "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
7403        );
7404
7405        // The add_bias_batch MSL kernel must be in the shader source.
7406        let shaders = generate_metal_shaders(&config);
7407        assert!(
7408            shaders.contains("kernel void add_bias_batch"),
7409            "shaders.metal must contain the add_bias_batch kernel"
7410        );
7411    }
7412
7413    #[test]
7414    fn llama_does_not_emit_qkv_bias_machinery() {
7415        // Negative test: non-Qwen2 models must NOT carry the bias dispatch,
7416        // buffer, or pipeline — keeps generated code lean for Llama/Phi/etc.
7417        let config = minimal_config();
7418        assert!(!config.qkv_bias);
7419        let model_rs = generate_model_rs(&config).unwrap();
7420        assert!(
7421            !model_rs.contains("qkv_bias: Buffer"),
7422            "Llama must not have qkv_bias field"
7423        );
7424        assert!(
7425            !model_rs.contains("add_bias_batch_pipeline"),
7426            "Llama must not pull in add_bias_batch_pipeline"
7427        );
7428        assert!(
7429            !model_rs.contains("dispatch_add_bias_batch"),
7430            "Llama must not call dispatch_add_bias_batch"
7431        );
7432    }
7433
7434    #[test]
7435    fn q4_dispatch_helper_takes_compute_encoder_ref() {
7436        let config = minimal_q4_config();
7437        let model_rs = generate_model_rs(&config).unwrap();
7438
7439        assert!(
7440            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
7441            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
7442        );
7443    }
7444
7445    #[test]
7446    fn f32_model_does_not_use_q4_dispatch() {
7447        let config = minimal_config();
7448        let model_rs = generate_model_rs(&config).unwrap();
7449
7450        let forward_start = model_rs
7451            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7452            .unwrap();
7453        let forward_body = &model_rs[forward_start..];
7454        let forward_end = forward_body
7455            .find("\n    fn dispatch_")
7456            .unwrap_or(forward_body.len());
7457        let forward_code = &forward_body[..forward_end];
7458
7459        assert!(
7460            !forward_code.contains("dispatch_matmul_q4"),
7461            "f32 model forward should not use dispatch_matmul_q4"
7462        );
7463    }
7464
7465    #[test]
7466    fn q4_model_lm_head_uses_q4_buffer() {
7467        let config = minimal_q4_config();
7468        let model_rs = generate_model_rs(&config).unwrap();
7469
7470        assert!(
7471            model_rs.contains("let lm_head_buf = next_q4_buffer"),
7472            "Q4_0 model should use next_q4_buffer for lm_head"
7473        );
7474    }
7475
7476    #[test]
7477    fn vec_tile_size_matches_model_dimensions() {
7478        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
7479        let small = minimal_config();
7480        let shaders_small = generate_metal_shaders(&small);
7481        assert!(
7482            shaders_small.contains("vec_tile[128]"),
7483            "vec_tile should be sized to max(hidden, intermediate) = 128"
7484        );
7485
7486        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
7487        let mut large = minimal_config();
7488        large.hidden_size = 2048;
7489        large.intermediate_size = 8192;
7490        let shaders_large = generate_metal_shaders(&large);
7491        assert!(
7492            shaders_large.contains("vec_tile[8192]"),
7493            "vec_tile should be 8192 for models with intermediate=8192"
7494        );
7495        assert!(
7496            !shaders_large.contains("vec_tile[4096]"),
7497            "vec_tile should NOT be hardcoded to 4096"
7498        );
7499    }
7500
7501    #[test]
7502    fn generated_cargo_toml_has_server_deps() {
7503        let toml = generate_cargo_toml("my-model");
7504        assert!(
7505            toml.contains("tiny_http"),
7506            "Cargo.toml should depend on tiny_http"
7507        );
7508        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
7509        assert!(
7510            toml.contains("serde_json"),
7511            "Cargo.toml should depend on serde_json"
7512        );
7513    }
7514
7515    #[test]
7516    fn generated_main_rs_has_serve_mode() {
7517        let config = minimal_config();
7518        let main_rs = generate_main_rs("test-model", &config).unwrap();
7519
7520        assert!(
7521            main_rs.contains("--serve"),
7522            "main.rs should parse --serve flag"
7523        );
7524        assert!(
7525            main_rs.contains("--port"),
7526            "main.rs should parse --port flag"
7527        );
7528        assert!(
7529            main_rs.contains("fn serve("),
7530            "main.rs should define serve function"
7531        );
7532        assert!(
7533            main_rs.contains("tiny_http::Server::http"),
7534            "main.rs should create tiny_http server"
7535        );
7536    }
7537
7538    #[test]
7539    fn generated_main_rs_has_chat_completions_endpoint() {
7540        let config = minimal_config();
7541        let main_rs = generate_main_rs("test-model", &config).unwrap();
7542
7543        assert!(
7544            main_rs.contains("/v1/chat/completions"),
7545            "main.rs should handle /v1/chat/completions endpoint"
7546        );
7547        assert!(
7548            main_rs.contains("/v1/models"),
7549            "main.rs should handle /v1/models endpoint"
7550        );
7551        assert!(
7552            main_rs.contains("/health"),
7553            "main.rs should handle /health endpoint"
7554        );
7555    }
7556
7557    #[test]
7558    fn generated_main_rs_has_sse_streaming() {
7559        let config = minimal_config();
7560        let main_rs = generate_main_rs("test-model", &config).unwrap();
7561
7562        assert!(
7563            main_rs.contains("text/event-stream"),
7564            "main.rs should set SSE content type for streaming"
7565        );
7566        assert!(
7567            main_rs.contains("chat.completion.chunk"),
7568            "main.rs should emit SSE chunks"
7569        );
7570        assert!(
7571            main_rs.contains("[DONE]"),
7572            "main.rs should emit [DONE] sentinel"
7573        );
7574    }
7575
7576    #[test]
7577    fn generated_main_rs_has_chat_message_formatting() {
7578        let config = minimal_config();
7579        let main_rs = generate_main_rs("test-model", &config).unwrap();
7580
7581        assert!(
7582            main_rs.contains("fn format_chat_messages"),
7583            "main.rs should define format_chat_messages function"
7584        );
7585        assert!(
7586            main_rs.contains("<|im_start|>"),
7587            "main.rs should use ChatML format"
7588        );
7589        assert!(
7590            main_rs.contains("<|im_end|>"),
7591            "main.rs should use ChatML format"
7592        );
7593    }
7594
7595    #[test]
7596    fn generated_main_rs_has_request_types() {
7597        let config = minimal_config();
7598        let main_rs = generate_main_rs("test-model", &config).unwrap();
7599
7600        assert!(
7601            main_rs.contains("struct ChatRequest"),
7602            "main.rs should define ChatRequest struct"
7603        );
7604        assert!(
7605            main_rs.contains("struct ChatMessage"),
7606            "main.rs should define ChatMessage struct"
7607        );
7608        assert!(
7609            main_rs.contains("Deserialize"),
7610            "main.rs should derive Deserialize for request types"
7611        );
7612    }
7613
7614    #[test]
7615    fn generated_model_has_reset_method() {
7616        let config = minimal_config();
7617        let model_rs = generate_model_rs(&config).unwrap();
7618
7619        assert!(
7620            model_rs.contains("pub fn reset(&mut self)"),
7621            "model.rs should have a reset() method for multi-request serving"
7622        );
7623        assert!(
7624            model_rs.contains("self.pos = 0"),
7625            "reset() should reset position to 0"
7626        );
7627    }
7628
7629    #[test]
7630    fn generated_main_rs_cli_mode_still_works() {
7631        let config = minimal_config();
7632        let main_rs = generate_main_rs("test-model", &config).unwrap();
7633
7634        // CLI mode should still be functional
7635        assert!(
7636            main_rs.contains("fn cli_mode("),
7637            "main.rs should define cli_mode function"
7638        );
7639        assert!(
7640            main_rs.contains("model.forward"),
7641            "main.rs should call model.forward"
7642        );
7643        assert!(
7644            main_rs.contains("model.forward_prefill"),
7645            "main.rs should call model.forward_prefill"
7646        );
7647    }
7648
7649    // ── Batched prefill tests ──────────────────────────────────────────
7650
7651    #[test]
7652    fn generated_shaders_contain_batch_kernels() {
7653        let shaders = generate_metal_shaders(&minimal_config());
7654
7655        assert!(
7656            shaders.contains("kernel void matmul_vec_batch"),
7657            "shaders should contain matmul_vec_batch kernel"
7658        );
7659        assert!(
7660            shaders.contains("kernel void matmul_vec_q8_batch"),
7661            "shaders should contain matmul_vec_q8_batch kernel"
7662        );
7663        assert!(
7664            shaders.contains("kernel void matmul_q8_gemm_batch"),
7665            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
7666        );
7667        assert!(
7668            shaders.contains("kernel void matmul_vec_q4_batch"),
7669            "shaders should contain matmul_vec_q4_batch kernel"
7670        );
7671        assert!(
7672            shaders.contains("kernel void rms_norm_batch"),
7673            "shaders should contain rms_norm_batch kernel"
7674        );
7675        assert!(
7676            shaders.contains("kernel void silu_mul_fused_batch"),
7677            "shaders should contain silu_mul_fused_batch kernel"
7678        );
7679        assert!(
7680            shaders.contains("kernel void add_inplace_batch"),
7681            "shaders should contain add_inplace_batch kernel"
7682        );
7683        assert!(
7684            shaders.contains("kernel void copy_embedding_batch"),
7685            "shaders should contain copy_embedding_batch kernel"
7686        );
7687    }
7688
7689    #[test]
7690    fn generated_model_has_batch_pipelines() {
7691        let config = minimal_config();
7692        let model_rs = generate_model_rs(&config).unwrap();
7693
7694        for pipeline in &[
7695            "matmul_batch_pipeline",
7696            "matmul_q8_batch_pipeline",
7697            "matmul_q8_gemm_batch_pipeline",
7698            "matmul_q4_batch_pipeline",
7699            "rms_norm_batch_pipeline",
7700            "rope_batch_pipeline",
7701            "silu_mul_fused_batch_pipeline",
7702            "add_inplace_batch_pipeline",
7703            "copy_embedding_batch_pipeline",
7704            "attention_batch_pipeline",
7705            "copy_kv_batch_pipeline",
7706        ] {
7707            assert!(
7708                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
7709                "MetalModel should have {pipeline} field"
7710            );
7711        }
7712    }
7713
7714    #[test]
7715    fn generated_model_has_batch_buffers() {
7716        let config = minimal_config();
7717        let model_rs = generate_model_rs(&config).unwrap();
7718
7719        for buf in &[
7720            "batch_hidden_buf",
7721            "batch_residual_buf",
7722            "batch_qkv_buf",
7723            "batch_attn_out_buf",
7724            "batch_attn_proj_buf",
7725            "batch_gate_up_buf",
7726            "batch_ffn_hidden_buf",
7727            "batch_ffn_out_buf",
7728            "batch_tokens_buf",
7729            "batch_positions_buf",
7730        ] {
7731            assert!(
7732                model_rs.contains(&format!("{buf}: Buffer")),
7733                "MetalModel should have {buf} field"
7734            );
7735        }
7736    }
7737
7738    #[test]
7739    fn generated_model_has_forward_prefill_batch() {
7740        let config = minimal_config();
7741        let model_rs = generate_model_rs(&config).unwrap();
7742
7743        assert!(
7744            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
7745            "MetalModel should have forward_prefill_batch method"
7746        );
7747
7748        // forward_prefill should delegate to forward_prefill_batch
7749        assert!(
7750            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
7751            "forward_prefill should delegate to forward_prefill_batch"
7752        );
7753    }
7754
7755    #[test]
7756    fn generated_model_has_max_batch_size_constant() {
7757        let config = minimal_config();
7758        let model_rs = generate_model_rs(&config).unwrap();
7759
7760        assert!(
7761            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
7762            "model.rs should define MAX_BATCH_SIZE constant"
7763        );
7764    }
7765
7766    #[test]
7767    fn forward_prefill_batch_uses_batch_dispatch() {
7768        let config = minimal_config();
7769        let model_rs = generate_model_rs(&config).unwrap();
7770
7771        let batch_start = model_rs
7772            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
7773            .unwrap();
7774        let batch_body = &model_rs[batch_start..];
7775        let batch_end = batch_body
7776            .find("\n    pub fn reset")
7777            .unwrap_or(batch_body.len());
7778        let batch_code = &batch_body[..batch_end];
7779
7780        // Should use batched dispatch methods
7781        assert!(
7782            batch_code.contains("dispatch_rms_norm_batch"),
7783            "forward_prefill_batch should use dispatch_rms_norm_batch"
7784        );
7785        assert!(
7786            batch_code.contains("dispatch_copy_embedding_batch"),
7787            "forward_prefill_batch should use dispatch_copy_embedding_batch"
7788        );
7789        assert!(
7790            batch_code.contains("dispatch_silu_mul_fused_batch"),
7791            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
7792        );
7793        // Should use batched causal attention dispatch
7794        assert!(
7795            batch_code.contains("dispatch_attention_batch"),
7796            "forward_prefill_batch should use dispatch_attention_batch"
7797        );
7798        // Should use fused KV cache copy (both K and V in one dispatch)
7799        assert!(
7800            batch_code.contains("dispatch_copy_kv_both_batch"),
7801            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
7802        );
7803        // Should use fused RoPE Q+K dispatch
7804        assert!(
7805            batch_code.contains("dispatch_rope_qk_batch"),
7806            "forward_prefill_batch should use dispatch_rope_qk_batch"
7807        );
7808    }
7809
7810    #[test]
7811    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
7812        let config = minimal_q8_config();
7813        let model_rs = generate_model_rs(&config).unwrap();
7814
7815        let batch_start = model_rs
7816            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
7817            .unwrap();
7818        let batch_body = &model_rs[batch_start..];
7819        let batch_end = batch_body
7820            .find("\n    pub fn reset")
7821            .unwrap_or(batch_body.len());
7822        let batch_code = &batch_body[..batch_end];
7823
7824        assert!(
7825            batch_code.contains("dispatch_matmul_q8_batch"),
7826            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
7827        );
7828    }
7829
7830    #[test]
7831    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
7832        let config = minimal_q4_config();
7833        let model_rs = generate_model_rs(&config).unwrap();
7834
7835        let batch_start = model_rs
7836            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
7837            .unwrap();
7838        let batch_body = &model_rs[batch_start..];
7839        let batch_end = batch_body
7840            .find("\n    pub fn reset")
7841            .unwrap_or(batch_body.len());
7842        let batch_code = &batch_body[..batch_end];
7843
7844        assert!(
7845            batch_code.contains("dispatch_matmul_q4_batch"),
7846            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
7847        );
7848    }
7849
7850    #[test]
7851    fn generated_main_rs_uses_batched_prefill() {
7852        let config = minimal_config();
7853        let main_rs = generate_main_rs("test-model", &config).unwrap();
7854
7855        assert!(
7856            main_rs.contains("forward_prefill_batch"),
7857            "main.rs should use forward_prefill_batch for prompt tokens"
7858        );
7859    }
7860
7861    #[test]
7862    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
7863        let config = minimal_config();
7864        let model_rs = generate_model_rs(&config).unwrap();
7865
7866        let batch_start = model_rs
7867            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
7868            .unwrap();
7869        let batch_body = &model_rs[batch_start..];
7870        let batch_end = batch_body
7871            .find("\n    pub fn reset")
7872            .unwrap_or(batch_body.len());
7873        let batch_code = &batch_body[..batch_end];
7874
7875        assert!(
7876            batch_code.contains("dispatch_matmul_batch"),
7877            "f32 forward_prefill_batch should use dispatch_matmul_batch"
7878        );
7879        // Should NOT use Q8 or Q4 batch dispatch
7880        assert!(
7881            !batch_code.contains("dispatch_matmul_q8_batch"),
7882            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
7883        );
7884        assert!(
7885            !batch_code.contains("dispatch_matmul_q4_batch"),
7886            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
7887        );
7888    }
7889
7890    #[test]
7891    fn forward_uses_cpu_embedding_lookup() {
7892        let config = minimal_config();
7893        let model_rs = generate_model_rs(&config).unwrap();
7894
7895        // Find just the forward() body (not forward_profile)
7896        let forward_start = model_rs
7897            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7898            .unwrap();
7899        let forward_body = &model_rs[forward_start..];
7900        let forward_end = forward_body
7901            .find("\n    pub fn forward_profile")
7902            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7903            .unwrap_or(forward_body.len());
7904        let forward_code = &forward_body[..forward_end];
7905
7906        // forward() should use CPU memcpy for embedding lookup (unified memory)
7907        assert!(
7908            forward_code.contains("embed_buf.contents()"),
7909            "forward() should access embed_buf via CPU unified memory for embedding lookup"
7910        );
7911        assert!(
7912            forward_code.contains("copy_nonoverlapping"),
7913            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
7914        );
7915        // forward() should NOT use GPU dispatch for embedding
7916        assert!(
7917            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
7918            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
7919        );
7920    }
7921
7922    #[test]
7923    fn forward_profile_method_exists() {
7924        let config = minimal_config();
7925        let model_rs = generate_model_rs(&config).unwrap();
7926
7927        assert!(
7928            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
7929            "MetalModel should have forward_profile() method"
7930        );
7931        // Profile method should print timing information
7932        assert!(
7933            model_rs.contains("[profile]"),
7934            "forward_profile() should print timing with [profile] prefix"
7935        );
7936        assert!(
7937            model_rs.contains("d_embed"),
7938            "forward_profile() should measure embedding time"
7939        );
7940        assert!(
7941            model_rs.contains("d_layers"),
7942            "forward_profile() should measure layer time"
7943        );
7944        assert!(
7945            model_rs.contains("d_logits"),
7946            "forward_profile() should measure logits time"
7947        );
7948    }
7949
7950    #[test]
7951    fn generated_cli_has_profile_flag() {
7952        let config = minimal_config();
7953        let main_rs = generate_main_rs("test-model", &config).unwrap();
7954
7955        assert!(
7956            main_rs.contains("--profile"),
7957            "CLI should support --profile flag"
7958        );
7959        assert!(
7960            main_rs.contains("forward_profile"),
7961            "CLI should call forward_profile when --profile is set"
7962        );
7963    }
7964
7965    #[test]
7966    fn generated_cli_has_thermal_yield() {
7967        let config = minimal_config();
7968        let main_rs = generate_main_rs("test-model", &config).unwrap();
7969
7970        assert!(
7971            main_rs.contains("yield_now()"),
7972            "CLI generation loop should include thread::yield_now() for thermal management"
7973        );
7974    }
7975
7976    // ── Real-world validation tests ──────────────────────────────────────
7977
7978    #[test]
7979    fn generated_forward_handles_single_token_prompt() {
7980        // With a single token (the first prompt token), forward() should work
7981        // at pos=0 where seq_len=1. The attention kernel must handle the case
7982        // where there is only one KV entry (no prefill context).
7983        let config = minimal_config();
7984        let model_rs = generate_model_rs(&config).unwrap();
7985
7986        // The forward function should accept any u32 token_id (no minimum pos guard)
7987        let forward_start = model_rs
7988            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7989            .expect("forward() must exist");
7990        let forward_body = &model_rs[forward_start..forward_start + 400];
7991
7992        // Should NOT require pos > 0 or seq_len > 1
7993        assert!(
7994            !forward_body.contains("assert!(self.pos > 0"),
7995            "forward() must accept pos=0 (first token with no prefill)"
7996        );
7997
7998        // The attention kernel should handle seq_len=1 via the pos field
7999        assert!(
8000            model_rs.contains("self.pos"),
8001            "forward() should use self.pos to track sequence position"
8002        );
8003    }
8004
8005    #[test]
8006    fn generated_reset_clears_kv_cache_position() {
8007        // After reset(), the model should be in a clean state. The pos field
8008        // must be 0 so new generation starts from scratch.
8009        let config = minimal_config();
8010        let model_rs = generate_model_rs(&config).unwrap();
8011
8012        let reset_start = model_rs
8013            .find("pub fn reset(&mut self)")
8014            .expect("reset() must exist");
8015        let reset_body = &model_rs[reset_start..reset_start + 200];
8016
8017        // Reset must zero the position counter
8018        assert!(
8019            reset_body.contains("self.pos = 0"),
8020            "reset() must set self.pos = 0"
8021        );
8022
8023        // Verify reset clears prev_cmd (double-buffering state)
8024        assert!(
8025            reset_body.contains("self.prev_cmd = None"),
8026            "reset() should clear prev_cmd for clean command buffer state"
8027        );
8028    }
8029
8030    #[test]
8031    fn generated_serve_handles_empty_messages_gracefully() {
8032        // The serve endpoint should not crash when receiving an empty messages array.
8033        // The format_chat_messages function should handle this gracefully.
8034        let config = minimal_config();
8035        let main_rs = generate_main_rs("test-model", &config).unwrap();
8036
8037        // The format_chat_messages function should exist and handle empty input
8038        let format_fn_start = main_rs
8039            .find("fn format_chat_messages")
8040            .expect("format_chat_messages must exist");
8041        let format_fn_body =
8042            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8043
8044        // It should iterate over messages (an empty slice produces an empty loop)
8045        assert!(
8046            format_fn_body.contains("for msg in messages"),
8047            "format_chat_messages should iterate over the messages slice"
8048        );
8049        // It should always append the assistant prompt suffix
8050        assert!(
8051            format_fn_body.contains("<|im_start|>assistant"),
8052            "format_chat_messages should always append assistant prompt header"
8053        );
8054
8055        // The serve function should call model.reset() before each request
8056        let serve_fn_start = main_rs
8057            .find("fn serve(")
8058            .expect("serve function must exist");
8059        let serve_fn_body = &main_rs[serve_fn_start..];
8060        assert!(
8061            serve_fn_body.contains("model.reset()"),
8062            "serve function should reset model between requests"
8063        );
8064    }
8065
8066    #[test]
8067    fn generated_model_forward_increments_pos() {
8068        // Each forward() call must increment self.pos so the next token
8069        // uses the correct RoPE position and KV cache offset.
8070        let config = minimal_config();
8071        let model_rs = generate_model_rs(&config).unwrap();
8072
8073        let forward_start = model_rs
8074            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8075            .unwrap();
8076        let forward_body = &model_rs[forward_start..];
8077        let forward_end = forward_body
8078            .find("\n    pub fn forward_profile")
8079            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
8080            .or_else(|| forward_body.find("\n    fn dispatch_"))
8081            .unwrap_or(forward_body.len());
8082        let forward_code = &forward_body[..forward_end];
8083
8084        assert!(
8085            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8086            "forward() must increment self.pos after processing a token"
8087        );
8088    }
8089}