Skip to main content

forgellm_codegen_metal/
lib.rs

1//! Forge Metal Code Generation — native Apple Silicon GPU inference.
2//!
3//! Generates a complete Cargo project that runs GPU inference via the `metal`
4//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
5//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
/// Errors during Metal code generation.
///
/// Returned by [`generate_metal_project`]. The `Io` and `Fmt` variants are
/// built automatically from their `std` counterparts via `#[from]`, so `?`
/// can be applied directly to `std::io::Result` / `std::fmt::Result` values.
#[derive(Debug, thiserror::Error)]
pub enum MetalCodegenError {
    /// The computation graph has no attached [`ModelConfig`], so there are no
    /// model dimensions (hidden size, intermediate size, ...) to generate
    /// code for.
    #[error("graph has no model config")]
    MissingConfig,

    /// An I/O error, e.g. while creating an output directory or writing a
    /// generated file.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// A formatting error while building generated source strings.
    #[error("format error: {0}")]
    Fmt(#[from] std::fmt::Error),
}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Convert an arbitrary model name into a valid Cargo package / binary name.
///
/// Rules:
/// - ASCII letters and digits are kept; letters are lower-cased.
/// - Each run of any other characters (spaces, `.`, `_`, `/`, non-ASCII, ...)
///   becomes a single `-`. Leading and trailing separators are dropped.
/// - A result that would start with a digit gets a `model-` prefix, because
///   Cargo rejects package names that begin with a digit.
/// - A result that would be empty becomes `model`, because Cargo rejects an
///   empty package name.
///
/// Non-ASCII characters are treated as separators. crates.io only accepts
/// ASCII names, and Unicode "alphanumerics" such as `²` fail Cargo's
/// identifier rules.
///
/// Examples: `"Llama-3.1 8B"` → `"llama-3-1-8b"`, `"7B chat"` →
/// `"model-7b-chat"`, `"!!!"` → `"model"`.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";
    const DIGIT_PREFIX: &str = "model-";

    let mut out = String::with_capacity(name.len());
    for c in name.chars() {
        if c.is_ascii_alphanumeric() {
            out.push(c.to_ascii_lowercase());
        } else if !out.is_empty() && !out.ends_with('-') {
            // Collapse each separator run into one `-`; skip leading ones.
            out.push('-');
        }
    }
    // After collapsing, at most one trailing separator can remain.
    if out.ends_with('-') {
        out.pop();
    }

    if out.is_empty() {
        return FALLBACK.to_string();
    }
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        out.insert_str(0, DIGIT_PREFIX);
    }
    out
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    r#"//
122// Auto-generated by ForgeLLM Metal codegen.
123// Metal Shading Language compute kernels for transformer inference.
124//
125// Optimized with simdgroup cooperative reductions, shared memory vector
126// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
127// lane, and fast:: math intrinsics for Apple Silicon throughput.
128//
129
130#include <metal_stdlib>
131#include <metal_simdgroup_matrix>
132using namespace metal;
133
134// ── Constants ───────────────────────────────────────────────────────────
135// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
136// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
137// per shared memory load vs 4-row, improving ILP and reducing launches.
138constant constexpr uint SIMDGROUPS_PER_TG = 8;
139constant constexpr uint ROWS_PER_SIMDGROUP = 8;
140constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
141
142// ── matmul_vec ──────────────────────────────────────────────────────────
143// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
144// Uses simdgroup cooperative dot product with shared memory vector caching
145// and float4 vectorized loads. Each simdgroup processes 8 rows for better
146// shared memory reuse (8x vector reuse per load) and instruction-level
147// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
148kernel void matmul_vec(
149    device const float* matrix [[buffer(0)]],
150    device const float* vector [[buffer(1)]],
151    device float* output       [[buffer(2)]],
152    constant uint& rows        [[buffer(3)]],
153    constant uint& cols        [[buffer(4)]],
154    uint tgid [[threadgroup_position_in_grid]],
155    uint tid [[thread_index_in_threadgroup]],
156    uint simd_lane [[thread_index_in_simdgroup]],
157    uint simd_id [[simdgroup_index_in_threadgroup]])
158{
159    // Cooperatively load vector into threadgroup shared memory
160    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
161    for (uint i = tid; i < cols; i += 256) {
162        vec_tile[i] = vector[i];
163    }
164    threadgroup_barrier(mem_flags::mem_threadgroup);
165
166    // Each simdgroup handles 8 consecutive rows
167    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
168    if (row_base >= rows) return;
169
170    uint base0 = row_base * cols;
171    uint base1 = (row_base + 1) * cols;
172    uint base2 = (row_base + 2) * cols;
173    uint base3 = (row_base + 3) * cols;
174    uint base4 = (row_base + 4) * cols;
175    uint base5 = (row_base + 5) * cols;
176    uint base6 = (row_base + 6) * cols;
177    uint base7 = (row_base + 7) * cols;
178
179    // float4 vectorized accumulation across 8 rows
180    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
181    float4 sum4_0 = float4(0.0f);
182    float4 sum4_1 = float4(0.0f);
183    float4 sum4_2 = float4(0.0f);
184    float4 sum4_3 = float4(0.0f);
185    float4 sum4_4 = float4(0.0f);
186    float4 sum4_5 = float4(0.0f);
187    float4 sum4_6 = float4(0.0f);
188    float4 sum4_7 = float4(0.0f);
189
190    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
191        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
192        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
193        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
194        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
195        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
196        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
197        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
198        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
199        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
200    }
201
202    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
203    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
204    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
205    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
206    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
207    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
208    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
209    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
210
211    // Handle remaining elements (cols not divisible by 128)
212    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
213        float vv = vec_tile[j];
214        sum0 += matrix[base0 + j] * vv;
215        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
216        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
217        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
218        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
219        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
220        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
221        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
222    }
223
224    // Simdgroup hardware warp-level reduction
225    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
226    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
227    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
228    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
229
230    // Only first lane writes the results
231    if (simd_lane == 0) {
232        if (row_base     < rows) output[row_base]     = sum0;
233        if (row_base + 1 < rows) output[row_base + 1] = sum1;
234        if (row_base + 2 < rows) output[row_base + 2] = sum2;
235        if (row_base + 3 < rows) output[row_base + 3] = sum3;
236        if (row_base + 4 < rows) output[row_base + 4] = sum4;
237        if (row_base + 5 < rows) output[row_base + 5] = sum5;
238        if (row_base + 6 < rows) output[row_base + 6] = sum6;
239        if (row_base + 7 < rows) output[row_base + 7] = sum7;
240    }
241}
242
243// ── rms_norm ────────────────────────────────────────────────────────────
244// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
245// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
246// via shared memory for minimal synchronization overhead.
247kernel void rms_norm(
248    device const float* input   [[buffer(0)]],
249    device const float* weight  [[buffer(1)]],
250    device float* output        [[buffer(2)]],
251    constant uint& n            [[buffer(3)]],
252    constant float& eps         [[buffer(4)]],
253    uint tid [[thread_index_in_threadgroup]])
254{
255    // Each thread accumulates partial sum-of-squares
256    float sum_sq = 0.0f;
257    for (uint i = tid; i < n; i += 256) {
258        float v = input[i];
259        sum_sq += v * v;
260    }
261
262    // Simdgroup-level reduction (hardware warp sum)
263    sum_sq = simd_sum(sum_sq);
264
265    // Cross-simdgroup reduction via shared memory
266    threadgroup float shared[8];
267    uint simd_id = tid / 32;
268    uint simd_lane = tid % 32;
269    if (simd_lane == 0) {
270        shared[simd_id] = sum_sq;
271    }
272    threadgroup_barrier(mem_flags::mem_threadgroup);
273
274    // First thread computes final inverse RMS
275    if (tid == 0) {
276        float total = 0.0f;
277        for (uint i = 0; i < 8; i++) {
278            total += shared[i];
279        }
280        shared[0] = fast::rsqrt(total / float(n) + eps);
281    }
282    threadgroup_barrier(mem_flags::mem_threadgroup);
283
284    float inv_rms = shared[0];
285
286    // Normalize
287    for (uint i = tid; i < n; i += 256) {
288        output[i] = input[i] * inv_rms * weight[i];
289    }
290}
291
292// ── rope ────────────────────────────────────────────────────────────────
293// Rotary Position Embedding applied in-place.
294// Each thread handles one (head, pair) combination.
295kernel void rope(
296    device float* data        [[buffer(0)]],
297    constant uint& num_heads  [[buffer(1)]],
298    constant uint& head_dim   [[buffer(2)]],
299    constant uint& pos        [[buffer(3)]],
300    constant float& theta     [[buffer(4)]],
301    uint id [[thread_position_in_grid]])
302{
303    uint half_dim = head_dim / 2;
304    uint total_pairs = num_heads * half_dim;
305    if (id >= total_pairs) return;
306
307    uint h = id / half_dim;
308    uint i = id % half_dim;
309    uint off = h * head_dim;
310
311    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
312    float angle = float(pos) * freq;
313    float c = cos(angle);
314    float s = sin(angle);
315
316    float x0 = data[off + 2 * i];
317    float x1 = data[off + 2 * i + 1];
318    data[off + 2 * i]     = x0 * c - x1 * s;
319    data[off + 2 * i + 1] = x0 * s + x1 * c;
320}
321
322// ── softmax ─────────────────────────────────────────────────────────────
323// Numerically stable softmax over a 1-D array.
324// Single-threadgroup kernel with cooperative reduction.
325kernel void softmax(
326    device float* data       [[buffer(0)]],
327    constant uint& n         [[buffer(1)]],
328    uint tid [[thread_index_in_threadgroup]],
329    uint tg_size [[threads_per_threadgroup]])
330{
331    threadgroup float shared_val[256];
332
333    // Pass 1: find max
334    float local_max = -INFINITY;
335    for (uint i = tid; i < n; i += tg_size) {
336        local_max = max(local_max, data[i]);
337    }
338    shared_val[tid] = local_max;
339    threadgroup_barrier(mem_flags::mem_threadgroup);
340
341    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
342        if (tid < stride) {
343            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
344        }
345        threadgroup_barrier(mem_flags::mem_threadgroup);
346    }
347    float max_val = shared_val[0];
348    threadgroup_barrier(mem_flags::mem_threadgroup);
349
350    // Pass 2: exp and sum
351    float local_sum = 0.0f;
352    for (uint i = tid; i < n; i += tg_size) {
353        float e = fast::exp(data[i] - max_val);
354        data[i] = e;
355        local_sum += e;
356    }
357    shared_val[tid] = local_sum;
358    threadgroup_barrier(mem_flags::mem_threadgroup);
359
360    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
361        if (tid < stride) {
362            shared_val[tid] += shared_val[tid + stride];
363        }
364        threadgroup_barrier(mem_flags::mem_threadgroup);
365    }
366    float inv_sum = 1.0f / shared_val[0];
367    threadgroup_barrier(mem_flags::mem_threadgroup);
368
369    // Pass 3: normalize
370    for (uint i = tid; i < n; i += tg_size) {
371        data[i] *= inv_sum;
372    }
373}
374
375// ── silu_mul ────────────────────────────────────────────────────────────
376// Fused SiLU activation * element-wise multiply:
377//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
378kernel void silu_mul(
379    device const float* gate [[buffer(0)]],
380    device const float* up   [[buffer(1)]],
381    device float* output     [[buffer(2)]],
382    constant uint& n         [[buffer(3)]],
383    uint id [[thread_position_in_grid]])
384{
385    if (id >= n) return;
386    float g = gate[id];
387    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
388}
389
390// ── silu_mul_fused ─────────────────────────────────────────────────────
391// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
392//   gate = gate_up[0..n], up = gate_up[n..2*n]
393//   output[i] = silu(gate_up[i]) * gate_up[n + i]
394kernel void silu_mul_fused(
395    device const float* gate_up [[buffer(0)]],
396    device float* output        [[buffer(1)]],
397    constant uint& n            [[buffer(2)]],
398    uint id [[thread_position_in_grid]])
399{
400    if (id >= n) return;
401    float g = gate_up[id];
402    float u = gate_up[n + id];
403    output[id] = (g / (1.0f + fast::exp(-g))) * u;
404}
405
406// ── elementwise_add ─────────────────────────────────────────────────────
407// Residual connection: output[i] = a[i] + b[i]
408kernel void elementwise_add(
409    device const float* a  [[buffer(0)]],
410    device const float* b  [[buffer(1)]],
411    device float* output   [[buffer(2)]],
412    constant uint& n       [[buffer(3)]],
413    uint id [[thread_position_in_grid]])
414{
415    if (id >= n) return;
416    output[id] = a[id] + b[id];
417}
418
419// ── copy_buffer ─────────────────────────────────────────────────────────
420// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
421// transitions. Used for KV cache updates and embedding lookup.
422kernel void copy_buffer(
423    device const float* src [[buffer(0)]],
424    device float* dst       [[buffer(1)]],
425    constant uint& count    [[buffer(2)]],
426    uint id [[thread_position_in_grid]])
427{
428    if (id < count) dst[id] = src[id];
429}
430
431// ── copy_offset ─────────────────────────────────────────────────────────
432// Copy with source offset (in floats). Used for embedding table lookup
433// where we need to copy a specific row from a large table.
434kernel void copy_offset(
435    device const float* src     [[buffer(0)]],
436    device float* dst           [[buffer(1)]],
437    constant uint& src_offset   [[buffer(2)]],  // in floats
438    constant uint& count        [[buffer(3)]],
439    uint id [[thread_position_in_grid]])
440{
441    if (id < count) dst[id] = src[src_offset + id];
442}
443
444// ── add_inplace ─────────────────────────────────────────────────────────
445// In-place residual connection: a[i] += b[i]
446// Avoids a separate blit copy for residual add, reducing encoder overhead.
447kernel void add_inplace(
448    device float* a        [[buffer(0)]],
449    device const float* b  [[buffer(1)]],
450    constant uint& n       [[buffer(2)]],
451    uint id [[thread_position_in_grid]])
452{
453    if (id >= n) return;
454    a[id] += b[id];
455}
456
457// ── matmul_vec_q8 ─────────────────────────────────────────────────────
458// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
459// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
460// Operates directly on quantized weights to halve memory bandwidth vs f32,
461// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
462//
463// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
464// because int8->float conversion doubles register demand.  Fully unrolled
465// inner loop with float4 vector loads from shared memory eliminates loop
466// overhead and enables better instruction scheduling.
467// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
468constant constexpr uint Q8_ROWS_PER_SG = 4;
469constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
470
471// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
472// doubles ALU work, so the same register budget applies).
473constant constexpr uint Q4_ROWS_PER_SG = 4;
474constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
475
476kernel void matmul_vec_q8(
477    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
478    device const float* vector   [[buffer(1)]],  // f32 input
479    device float* output         [[buffer(2)]],
480    constant uint& rows          [[buffer(3)]],
481    constant uint& cols          [[buffer(4)]],  // number of elements per row
482    uint tgid [[threadgroup_position_in_grid]],
483    uint tid [[thread_index_in_threadgroup]],
484    uint simd_lane [[thread_index_in_simdgroup]],
485    uint simd_id [[simdgroup_index_in_threadgroup]])
486{
487    // Load vector into shared memory
488    threadgroup float vec_tile[VEC_TILE_SIZE];
489    for (uint i = tid; i < cols; i += 256) {
490        vec_tile[i] = vector[i];
491    }
492    threadgroup_barrier(mem_flags::mem_threadgroup);
493
494    // Each simdgroup handles 4 consecutive rows (lower register pressure)
495    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
496    if (row_base >= rows) return;
497
498    // Q8_0: each block is 34 bytes for 32 elements
499    uint blocks_per_row = cols / 32;
500    uint row_bytes = blocks_per_row * 34;
501
502    // Pointers to each row's Q8_0 data
503    device const uchar* r0 = matrix + row_base * row_bytes;
504    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
505    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
506    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
507
508    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
509
510    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
511        uint bb = blk * 34;
512        uint vb = blk * 32;
513
514        // Prefetch all 4 scales
515        float sc0 = float(*(device const half*)(r0 + bb));
516        float sc1 = float(*(device const half*)(r1 + bb));
517        float sc2 = float(*(device const half*)(r2 + bb));
518        float sc3 = float(*(device const half*)(r3 + bb));
519
520        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
521        // Q8_0 block layout where the int8 data starts at offset +2 from a
522        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
523        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
524        // reduction in memory transactions. Metal's char16/packed_char16 are
525        // reserved types and packed_*int4 require >=4-byte alignment which
526        // this layout does not provide, so packed_short4 is the widest valid
527        // vectorized load.
528        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
529        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
530        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
531        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
532
533        // Load all 8 float4 vector values for this 32-element block from shared memory
534        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
535        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
536        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
537        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
538        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
539        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
540        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
541        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
542
543        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
544        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
545        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
546            short4 _s = short4(SHORT4); \
547            char2 _a = as_type<char2>(_s.x); \
548            char2 _b = as_type<char2>(_s.y); \
549            char2 _c = as_type<char2>(_s.z); \
550            char2 _d = as_type<char2>(_s.w); \
551            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
552            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
553        }
554
555        float4 f0, f1;
556        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
557
558        // Row 0: 4 short4 loads cover 32 int8 weights
559        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
560        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
561        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
562        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
563
564        // Row 1
565        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
566        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
567        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
568        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
569
570        // Row 2
571        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
572        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
573        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
574        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
575
576        // Row 3
577        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
578        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
579        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
580        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
581
582        #undef Q8_UNPACK8
583
584        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
585    }
586
587    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
588    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
589
590    if (simd_lane == 0) {
591        if (row_base     < rows) output[row_base]     = sum0;
592        if (row_base + 1 < rows) output[row_base + 1] = sum1;
593        if (row_base + 2 < rows) output[row_base + 2] = sum2;
594        if (row_base + 3 < rows) output[row_base + 3] = sum3;
595    }
596}
597
598// ── matmul_vec_q4 ─────────────────────────────────────────────────────
599// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
600// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
601// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
602// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
603//
604// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
605// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
606kernel void matmul_vec_q4(
607    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
608    device const float* vector   [[buffer(1)]],  // f32 input
609    device float* output         [[buffer(2)]],
610    constant uint& rows          [[buffer(3)]],
611    constant uint& cols          [[buffer(4)]],  // number of elements per row
612    uint tgid [[threadgroup_position_in_grid]],
613    uint tid [[thread_index_in_threadgroup]],
614    uint simd_lane [[thread_index_in_simdgroup]],
615    uint simd_id [[simdgroup_index_in_threadgroup]])
616{
617    // Load vector into shared memory
618    threadgroup float vec_tile[VEC_TILE_SIZE];
619    for (uint i = tid; i < cols; i += 256) {
620        vec_tile[i] = vector[i];
621    }
622    threadgroup_barrier(mem_flags::mem_threadgroup);
623
624    // Each simdgroup handles 4 consecutive rows
625    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
626    if (row_base >= rows) return;
627
628    // Q4_0: each block is 18 bytes for 32 elements
629    uint blocks_per_row = cols / 32;
630    uint row_bytes = blocks_per_row * 18;
631
632    // Pointers to each row's Q4_0 data
633    device const uchar* r0 = matrix + row_base * row_bytes;
634    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
635    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
636    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
637
638    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
639
640    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
641        uint bb = blk * 18;
642        uint vb = blk * 32;
643
644        // Prefetch all 4 scales
645        float sc0 = float(*(device const half*)(r0 + bb));
646        float sc1 = float(*(device const half*)(r1 + bb));
647        float sc2 = float(*(device const half*)(r2 + bb));
648        float sc3 = float(*(device const half*)(r3 + bb));
649
650        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
651        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
652        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
653        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
654        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
655
656        // Load 8 float4 vector values for 32 elements from shared memory
657        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
658        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
659        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
660        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
661        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
662        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
663        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
664        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
665        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
666
667        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
668        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
669        float bd0=0, bd1=0, bd2=0, bd3=0;
670        uchar4 b;
671
672        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
673        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
674                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
675                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
676                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
677        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
678                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
679                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
680                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
681        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
682                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
683                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
684                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
685        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
686                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
687                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
688                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
689
690        // Row 1
691        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
692                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
693                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
694                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
695        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
696                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
697                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
698                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
699        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
700                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
701                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
702                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
703        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
704                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
705                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
706                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
707
708        // Row 2
709        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
710                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
711                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
712                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
713        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
714                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
715                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
716                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
717        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
718                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
719                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
720                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
721        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
722                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
723                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
724                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
725
726        // Row 3
727        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
728                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
729                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
730                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
731        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
732                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
733                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
734                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
735        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
736                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
737                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
738                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
739        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
740                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
741                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
742                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
743
744        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
745    }
746
747    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
748    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
749
750    if (simd_lane == 0) {
751        if (row_base     < rows) output[row_base]     = sum0;
752        if (row_base + 1 < rows) output[row_base + 1] = sum1;
753        if (row_base + 2 < rows) output[row_base + 2] = sum2;
754        if (row_base + 3 < rows) output[row_base + 3] = sum3;
755    }
756}
757
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then weighted sum of V.
// Each threadgroup handles one head with 256 threads (8 simdgroups).
//
// Keys are processed in tiles of at most 2048 positions so the threadgroup
// score buffer can never overflow, whatever seq_len is. Across tiles an
// online softmax is used: a running max and running exp-sum are kept, and
// the partial V accumulation stored in `output` is rescaled by
// exp(old_max - new_max) whenever the max grows. For seq_len <= 2048 there
// is a single tile and the result equals the one-pass softmax up to float
// rounding.
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;
    uint kv_stride = num_kv_heads * head_dim;  // floats per cached position
    uint kv_off = kv_head * head_dim;          // this KV head within a position
    float scale = fast::rsqrt(float(head_dim));

    // Empty context: emit zeros instead of dividing by a zero exp-sum.
    // seq_len is uniform, so the whole threadgroup leaves together.
    if (seq_len == 0) {
        for (uint d = tid; d < head_dim; d += 256) {
            output[q_off + d] = 0.0;
        }
        return;
    }

    // Score tile (8 KB) plus cross-simdgroup reduction slots. Contexts longer
    // than the tile are handled by the tile loop below, not by overflowing.
    threadgroup float scores[2048];
    threadgroup float shared_max[8];
    threadgroup float shared_sum[8];
    const uint tile_cap = 2048;

    // Running softmax statistics. Every thread holds the same values because
    // they are derived only from threadgroup-reduced quantities.
    float run_max = -INFINITY;
    float run_sum = 0.0;

    for (uint t0 = 0; t0 < seq_len; t0 += tile_cap) {
        uint tile_len = min(tile_cap, seq_len - t0);
        bool first_tile = (t0 == 0);

        // Step 1: scaled scores Q·K^T for positions [t0, t0 + tile_len)
        for (uint s = simd_id; s < tile_len; s += 8) {
            uint k_off = (t0 + s) * kv_stride + kv_off;
            float dot = 0.0;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q[q_off + d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores[s] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Step 2a: maximum score within this tile
        float local_max = -INFINITY;
        for (uint s = tid; s < tile_len; s += 256) {
            local_max = max(local_max, scores[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) shared_max[simd_id] = local_max;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float m = shared_max[0];
            for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
            shared_max[0] = m;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        float tile_max = shared_max[0];
        float new_max = first_tile ? tile_max : max(run_max, tile_max);
        // Rescales everything accumulated against the previous max. Skipped
        // on the first tile so exp(-inf) is never evaluated.
        float correction = first_tile ? 0.0f : fast::exp(run_max - new_max);

        // Step 2b: exponentiate in place and sum this tile
        float local_sum = 0.0;
        for (uint s = tid; s < tile_len; s += 256) {
            float p = fast::exp(scores[s] - new_max);
            scores[s] = p;
            local_sum += p;
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) shared_sum[simd_id] = local_sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float total = 0.0;
            for (uint i = 0; i < 8; i++) total += shared_sum[i];
            shared_sum[0] = total;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        run_sum = run_sum * correction + shared_sum[0];
        run_max = new_max;

        // Step 3: rescale the running V accumulation and add this tile's
        // contribution. Each thread owns a fixed set of dimensions d, so the
        // read-modify-write of output needs no cross-thread synchronization.
        // The 4-wide unroll keeps the original ILP-friendly access pattern.
        uint tile_len4 = tile_len & ~3u;  // largest multiple of 4 <= tile_len
        for (uint d = tid; d < head_dim; d += 256) {
            uint v_base = kv_off + d;
            float acc = 0.0;
            for (uint s = 0; s < tile_len4; s += 4) {
                uint p0 = (t0 + s) * kv_stride + v_base;
                acc += scores[s]     * v_cache[p0]
                     + scores[s + 1] * v_cache[p0 + kv_stride]
                     + scores[s + 2] * v_cache[p0 + 2 * kv_stride]
                     + scores[s + 3] * v_cache[p0 + 3 * kv_stride];
            }
            for (uint s = tile_len4; s < tile_len; s++) {
                acc += scores[s] * v_cache[(t0 + s) * kv_stride + v_base];
            }
            float prev = first_tile ? 0.0f : output[q_off + d] * correction;
            output[q_off + d] = prev + acc;
        }
        // scores[] and the reduction slots are overwritten by the next tile.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Final normalization by the total softmax denominator.
    float inv_sum = 1.0 / run_sum;
    for (uint d = tid; d < head_dim; d += 256) {
        output[q_off + d] *= inv_sum;
    }
}
874
875// ── Batched prefill kernels ────────────────────────────────────────────
876// These kernels process M input vectors against the same weight matrix
877// in a single dispatch, converting mat-vec into mat-mat for better GPU
878// utilization during prompt prefill.
879
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization. Threadgroup `tgid` normalizes row `tgid` of the
// [num_tokens, n] input: out = x * rsqrt(mean(x^2) + eps) * weight.
// Expects one 256-thread threadgroup (8 simdgroups of 32) per token.
kernel void rms_norm_batch(
    device const float* input   [[buffer(0)]],  // [M, n]
    device const float* weight  [[buffer(1)]],  // [n]
    device float* output        [[buffer(2)]],  // [M, n]
    constant uint& n            [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint& num_tokens   [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // Surplus threadgroups beyond the batch exit as a whole.
    if (tgid >= num_tokens) return;

    uint row_start = tgid * n;

    // Per-thread partial sum of squares over a 256-wide stride.
    float partial = 0.0f;
    for (uint k = tid; k < n; k += 256) {
        float x = input[row_start + k];
        partial += x * x;
    }

    // Reduce inside each simdgroup, then across the 8 simdgroups.
    partial = simd_sum(partial);

    threadgroup float sg_totals[8];
    uint sg = tid / 32;
    if (tid % 32 == 0) {
        sg_totals[sg] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float acc = 0.0f;
        for (uint g = 0; g < 8; ++g) {
            acc += sg_totals[g];
        }
        // Thread 0 publishes the inverse RMS through slot 0.
        sg_totals[0] = fast::rsqrt(acc / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = sg_totals[0];
    for (uint k = tid; k < n; k += 256) {
        output[row_start + k] = input[row_start + k] * scale * weight[k];
    }
}
929
// ── rope_batch ─────────────────────────────────────────────────────────
// Batched rotary position embedding, applied in place.
// data is [M, num_heads * head_dim]; positions[t] is the position of token t.
// One thread per (token, head, pair): elements (2*i, 2*i + 1) of a head are
// rotated by angle positions[token] / theta^(2*i / head_dim).
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    uint half_dim = head_dim / 2;
    uint pairs_per_token = num_heads * half_dim;
    if (id >= num_tokens * pairs_per_token) return;

    // Split the flat thread id into (token, head, pair index).
    uint token = id / pairs_per_token;
    uint pair_in_token = id - token * pairs_per_token;
    uint head = pair_in_token / half_dim;
    uint pair = pair_in_token - head * half_dim;

    uint elem = token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    float angle = float(positions[token]) * inv_freq;
    float cos_a = cos(angle);
    float sin_a = sin(angle);

    float even = data[elem];
    float odd = data[elem + 1];
    data[elem]     = even * cos_a - odd * sin_a;
    data[elem + 1] = even * sin_a + odd * cos_a;
}
964
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Batched fused SiLU-multiply. Row t of gate_up holds [gate (n) | up (n)];
// output[t*n + i] = silu(gate[i]) * up[i], where silu(g) = g / (1 + exp(-g)).
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    uint token = id / n;
    uint col = id % n;
    uint row_start = token * 2 * n;

    float gate = gate_up[row_start + col];
    float up = gate_up[row_start + n + col];
    float activated = gate / (1.0f + fast::exp(-gate));
    output[token * n + col] = activated * up;
}
984
// ── add_inplace_batch ──────────────────────────────────────────────────
// Batched residual add, in place: a[k] += b[k] for every k < total (M * n).
// One thread per element; threads past `total` do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
996
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M embedding rows into a contiguous [M, dim] buffer.
// One thread per output element: output[t*dim + d] = embed[tokens[t]*dim + d].
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    uint count = num_tokens * dim;
    if (id >= count) return;

    uint slot = id / dim;
    uint col = id - slot * dim;
    uint vocab_row = tokens[slot];
    output[id] = embed[vocab_row * dim + col];
}
1014
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair.
// Within a threadgroup, each simdgroup computes 8 consecutive output rows
// (base0..base7 below) and reduces across its 32 lanes with simd_sum.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decode the flat threadgroup index: consecutive tgids walk all row
    // groups of one token before moving on to the next token.
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    // (stride 256 assumes a 256-thread threadgroup).
    // NOTE(review): nothing here checks cols <= VEC_TILE_SIZE (defined
    // elsewhere in this shader) — confirm the host never dispatches wider rows.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // First of the 8 rows owned by this simdgroup. The 8-row layout below is
    // hard-coded; presumably ROWS_PER_SIMDGROUP == 8 — verify.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    // Safe to return here: no threadgroup barriers follow.
    if (row_base >= rows) return;

    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Largest multiple of 128 <= cols: each step covers one float4 per lane
    // across 32 lanes. (Despite the name, this is a multiple of 128, not 4.)
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    // Vectorized main loop. Row 0 is always in range (checked above); rows
    // 1..7 are guarded for a partial final group.
    // NOTE(review): the float4 loads at matrix + baseK + j need baseK
    // (= row * cols) to be float4-aligned, i.e. cols % 4 == 0 — confirm.
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Collapse each float4 accumulator to a per-lane scalar.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail for columns beyond cols_vec4, striped across lanes.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Sum the per-lane partials across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    // Lane 0 writes this token's results; out-of-range rows are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1116
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups; each threadgroup handles
// one (token, row group) pair, and the body below computes 4 consecutive
// rows (r0..r3) per simdgroup.
//
// Weight row layout: cols/32 blocks of 34 bytes each — a half scale followed
// by 32 signed 8-bit values. cols is assumed to be a multiple of 32.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4, the trailing rows of the last
    // simdgroup would point past the end of the weight buffer. Clamp them to
    // the last valid row instead. The loads stay in bounds (and branch-free),
    // and the duplicated results are dropped by the guarded stores below.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;
        uint vb = blk * 32;

        // Per-block half-precision scale (first 2 bytes of the block).
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Reinterpret 4 shorts as 8 signed bytes (low byte first) and widen
        // them to two float4 vectors.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply the block scale once per block rather than per element.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Sum the per-lane partials across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes; rows past the end (including clamped duplicates) are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1232
1233// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
1234// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
1235// Each threadgroup covers 32 rows and TOKENS_PER_TG consecutive tokens, so
1236// the Q8_0 weight blocks are fetched once from device memory and reused for
1237// every token in the tile (1/TOKENS_PER_TG the weight bandwidth of the
1238// per-token dispatch).
1239//
1240// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG)) threadgroups.
1241// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
1242// with simd_sum.  Token vectors are read directly from device memory inside
1243// the block loop (not cached in shared memory) so intermediate_size up to
1244// 8192 fits without spilling threadgroup memory.
// Token tile height for matmul_q8_gemm_batch: each threadgroup along grid y
// covers this many consecutive tokens (tok_base = tgid.y * TOKENS_PER_TG_Q8),
// and the kernel's hand-written accumulators are laid out for 4 tokens.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1246
1247kernel void matmul_q8_gemm_batch(
1248    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1249    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1250    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1251    constant uint& num_tokens    [[buffer(3)]],  // M
1252    constant uint& rows          [[buffer(4)]],
1253    constant uint& cols          [[buffer(5)]],
1254    uint2 tgid [[threadgroup_position_in_grid]],
1255    uint tid [[thread_index_in_threadgroup]],
1256    uint simd_lane [[thread_index_in_simdgroup]],
1257    uint simd_id [[simdgroup_index_in_threadgroup]])
1258{
1259    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1260    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1261    if (row_base >= rows || tok_base >= num_tokens) return;
1262
1263    // How many tokens in this tile are valid?
1264    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1265
1266    uint blocks_per_row = cols / 32;
1267    uint row_bytes = blocks_per_row * 34;
1268
1269    device const uchar* r0 = matrix + row_base * row_bytes;
1270    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1271    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1272    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1273
1274    // Accumulators: 4 tokens × 4 rows per simdgroup.
1275    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1276    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1277    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1278    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1279
1280    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1281        uint bb = blk * 34;
1282        uint vb = blk * 32;
1283
1284        // ── Load weight data ONCE per block (reused across all tokens) ──
1285        float sc0 = float(*(device const half*)(r0 + bb));
1286        float sc1 = float(*(device const half*)(r1 + bb));
1287        float sc2 = float(*(device const half*)(r2 + bb));
1288        float sc3 = float(*(device const half*)(r3 + bb));
1289
1290        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1291        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1292        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1293        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1294
1295        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1296            short4 _s = short4(SHORT4); \
1297            char2 _a = as_type<char2>(_s.x); \
1298            char2 _b = as_type<char2>(_s.y); \
1299            char2 _c = as_type<char2>(_s.z); \
1300            char2 _d = as_type<char2>(_s.w); \
1301            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1302            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1303        }
1304
1305        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1306        // registers for the duration of the block and are dotted against
1307        // every token's vector tile.
1308        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1309        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1310        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1311        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1312
1313        Q8_UNPACK8(d0[0], w0_0, w0_1);
1314        Q8_UNPACK8(d0[1], w0_2, w0_3);
1315        Q8_UNPACK8(d0[2], w0_4, w0_5);
1316        Q8_UNPACK8(d0[3], w0_6, w0_7);
1317
1318        Q8_UNPACK8(d1[0], w1_0, w1_1);
1319        Q8_UNPACK8(d1[1], w1_2, w1_3);
1320        Q8_UNPACK8(d1[2], w1_4, w1_5);
1321        Q8_UNPACK8(d1[3], w1_6, w1_7);
1322
1323        Q8_UNPACK8(d2[0], w2_0, w2_1);
1324        Q8_UNPACK8(d2[1], w2_2, w2_3);
1325        Q8_UNPACK8(d2[2], w2_4, w2_5);
1326        Q8_UNPACK8(d2[3], w2_6, w2_7);
1327
1328        Q8_UNPACK8(d3[0], w3_0, w3_1);
1329        Q8_UNPACK8(d3[1], w3_2, w3_3);
1330        Q8_UNPACK8(d3[2], w3_4, w3_5);
1331        Q8_UNPACK8(d3[3], w3_6, w3_7);
1332
1333        #undef Q8_UNPACK8
1334
1335        // ── For each token, read vector and accumulate against shared weights ──
1336        // Token 0 (always valid: tok_count >= 1).
1337        {
1338            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1339            float4 v0 = *(device const float4*)(a0);
1340            float4 v1 = *(device const float4*)(a0 + 4);
1341            float4 v2 = *(device const float4*)(a0 + 8);
1342            float4 v3 = *(device const float4*)(a0 + 12);
1343            float4 v4 = *(device const float4*)(a0 + 16);
1344            float4 v5 = *(device const float4*)(a0 + 20);
1345            float4 v6 = *(device const float4*)(a0 + 24);
1346            float4 v7 = *(device const float4*)(a0 + 28);
1347            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1348                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1349            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1350                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1351            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1352                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1353            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1354                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1355            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1356        }
1357        // Token 1
1358        if (tok_count > 1) {
1359            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1360            float4 v0 = *(device const float4*)(a1);
1361            float4 v1 = *(device const float4*)(a1 + 4);
1362            float4 v2 = *(device const float4*)(a1 + 8);
1363            float4 v3 = *(device const float4*)(a1 + 12);
1364            float4 v4 = *(device const float4*)(a1 + 16);
1365            float4 v5 = *(device const float4*)(a1 + 20);
1366            float4 v6 = *(device const float4*)(a1 + 24);
1367            float4 v7 = *(device const float4*)(a1 + 28);
1368            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1369                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1370            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1371                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1372            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1373                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1374            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1375                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1376            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1377        }
1378        // Token 2
1379        if (tok_count > 2) {
1380            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1381            float4 v0 = *(device const float4*)(a2);
1382            float4 v1 = *(device const float4*)(a2 + 4);
1383            float4 v2 = *(device const float4*)(a2 + 8);
1384            float4 v3 = *(device const float4*)(a2 + 12);
1385            float4 v4 = *(device const float4*)(a2 + 16);
1386            float4 v5 = *(device const float4*)(a2 + 20);
1387            float4 v6 = *(device const float4*)(a2 + 24);
1388            float4 v7 = *(device const float4*)(a2 + 28);
1389            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1390                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1391            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1392                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1393            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1394                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1395            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1396                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1397            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1398        }
1399        // Token 3
1400        if (tok_count > 3) {
1401            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1402            float4 v0 = *(device const float4*)(a3);
1403            float4 v1 = *(device const float4*)(a3 + 4);
1404            float4 v2 = *(device const float4*)(a3 + 8);
1405            float4 v3 = *(device const float4*)(a3 + 12);
1406            float4 v4 = *(device const float4*)(a3 + 16);
1407            float4 v5 = *(device const float4*)(a3 + 20);
1408            float4 v6 = *(device const float4*)(a3 + 24);
1409            float4 v7 = *(device const float4*)(a3 + 28);
1410            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1411                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1412            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1413                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1414            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1415                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1416            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1417                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1418            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1419        }
1420    }
1421
1422    // simdgroup reduction
1423    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1424    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1425    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1426    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1427
1428    if (simd_lane == 0) {
1429        device float* o0 = outputs + (tok_base + 0) * rows;
1430        if (row_base     < rows) o0[row_base]     = s00;
1431        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1432        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1433        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1434
1435        if (tok_count > 1) {
1436            device float* o1 = outputs + (tok_base + 1) * rows;
1437            if (row_base     < rows) o1[row_base]     = s10;
1438            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1439            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1440            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1441        }
1442        if (tok_count > 2) {
1443            device float* o2 = outputs + (tok_base + 2) * rows;
1444            if (row_base     < rows) o2[row_base]     = s20;
1445            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1446            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1447            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1448        }
1449        if (tok_count > 3) {
1450            device float* o3 = outputs + (tok_base + 3) * rows;
1451            if (row_base     < rows) o3[row_base]     = s30;
1452            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1453            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1454            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1455        }
1456    }
1457}
1458
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate).  Intended for
// prompt prefill, where many token vectors share every weight tile.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8_0
// block).  4 simdgroups (128 threads) per TG, each computing a single 8×8
// output sub-tile via one simdgroup_matrix<float, 8, 8> accumulator.  Weight
// bytes are cooperatively dequantized into threadgroup memory once per block
// and reused by all simdgroups in the tile.
//
// Assumptions (verified in the dispatch helper, which falls back otherwise):
//   * threadgroup size == 128 threads (4 simdgroups of 32 lanes)
//   * cols  % 32 == 0   (one Q8_0 block per K chunk)
//   * rows  % 16 == 0   (tile-aligned; no row-bounds checks are performed)
//   * num_tokens may be any value; a partial last token tile is written
//     through a threadgroup staging buffer with per-token bounds checks.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together before
    // reaching any barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    // After the K loop, w_tile is reused as the output staging buffer for a
    // partial token tile, so no extra threadgroup array is needed.
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;  // token offset within the output tile
    uint sg_row_base = (simd_id % 2) * 8;  // weight-row offset within the output tile

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // Q8_0 block: 2-byte f16 scale + 32 int8

    // Per-thread load slot: 4 consecutive elements (512 / 128 threads).  The
    // slot never straddles a 32-wide tile row, so the tile row, the weight
    // row pointer and the token validity are invariant across K blocks.
    uint ld_base = tid * 4;
    uint ld_row  = ld_base / 32;
    uint ld_k    = ld_base % 32;
    device const uchar* w_row = matrix + (row_base + ld_row) * row_bytes;
    uint ld_tok = tok_base + ld_row;
    bool ld_tok_valid = ld_tok < num_tokens;
    device const float* t_row = inputs + ld_tok * cols;  // dereferenced only if ld_tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Dequantize this thread's 4 weights of Q8_0 block `blk` into w_tile ──
        device const uchar* bp = w_row + blk * 34;
        float sc = float(*(device const half*)bp);
        device const int8_t* qs = (device const int8_t*)(bp + 2);
        for (uint ii = 0; ii < 4; ii++) {
            w_tile[ld_base + ii] = float(int(qs[ld_k + ii])) * sc;
        }

        // ── Copy this thread's 4 activations into t_tile (zero past num_tokens) ──
        for (uint ii = 0; ii < 4; ii++) {
            t_tile[ld_base + ii] = ld_tok_valid
                ? t_row[blk * 32 + ld_k + ii]
                : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B,
                w_tile + sg_row_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Tiles are overwritten by the next iteration (or reused as staging
        // below), so every simdgroup must be done reading them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    if (out_tok + 8 <= num_tokens) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial last token tile: stage the accumulator in w_tile and copy
        // only the valid token rows.  Safe because the barrier closing the
        // last K iteration guarantees nobody still reads w_tile, and each
        // simdgroup owns a disjoint 64-float slice (4 × 64 = 256 <= 512).
        threadgroup float* scratch = w_tile + simd_id * 64;
        simdgroup_store(C, scratch, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst[j] = src[j];
                }
            }
        }
    }
}
1587
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 4×4 grid of 8×8 output sub-tiles,
// with each simdgroup owning *two* adjacent 8×8 accumulators along the row axis:
//
//     simd_id = 2*sg_tok_idx + sg_row_half          (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//     output sub-tiles (tok, row):
//         (sg_tok_idx*8, sg_row_half*16 +  0)  -> C_a
//         (sg_tok_idx*8, sg_row_half*16 +  8)  -> C_b
//
// The loaded A (token) simdgroup_matrix is reused for both accumulators on
// every K_sub step (better FLOP/load ratio than the 16×16 single-accumulator
// kernel).  Because both tile dimensions double, the grid needs a quarter
// of the threadgroups the 16×16 tile does.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * threadgroup size == 256 threads (8 simdgroups of 32 lanes)
//   * cols % 32 == 0
//   * rows % 32 == 0   (no row-bounds checks are performed)
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit (depends only on tgid), taken before any barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    // w_tile doubles as the output staging buffer for a partial token tile.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx    = simd_id / 2;      // 0..3
    uint sg_row_half   = simd_id % 2;      // 0..1
    uint sg_tok_base   = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // Q8_0 block: 2-byte f16 scale + 32 int8

    // Per-thread load slot: 4 consecutive elements of one 32-wide tile row
    // (1024 elements / 256 threads).  Row, weight pointer and token validity
    // do not depend on the K block, so they are computed once.
    uint ld_base = tid * 4;
    uint ld_row  = ld_base / 32;
    uint ld_k    = ld_base % 32;
    device const uchar* w_row = matrix + (row_base + ld_row) * row_bytes;
    uint ld_tok = tok_base + ld_row;
    bool ld_tok_valid = ld_tok < num_tokens;
    device const float* t_row = inputs + ld_tok * cols;  // dereferenced only if ld_tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Dequantize this thread's 4 weights of Q8_0 block `blk` into w_tile.
        device const uchar* bp = w_row + blk * 34;
        float sc = float(*(device const half*)bp);
        device const int8_t* qs = (device const int8_t*)(bp + 2);
        for (uint ii = 0; ii < 4; ii++) {
            w_tile[ld_base + ii] = float(int(qs[ld_k + ii])) * sc;
        }

        // Copy this thread's 4 activations into t_tile, zero-padding past num_tokens.
        for (uint ii = 0; ii < 4; ii++) {
            t_tile[ld_base + ii] = ld_tok_valid
                ? t_row[blk * 32 + ld_k + ii]
                : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // All simdgroups must finish reading the tiles before they are
        // overwritten (next iteration) or reused as staging (below).
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators.  rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Stage through w_tile instead of a separate threadgroup array: the
        // final loop barrier means it is no longer read, and the per-simdgroup
        // slices (8 simdgroups × 2 accs × 64 floats = 1024) fill it exactly.
        threadgroup float* scratch = w_tile + simd_id * 128;
        simdgroup_store(C_a, scratch, 8);
        simdgroup_store(C_b, scratch + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + m * 8;
                threadgroup const float* src_b = scratch + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1728
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory: the two 32×32 tiles take 4 KB instead of mma32's 8 KB, and the
// same 4 KB allocation is reused as the FP32 staging buffer for a partial
// token tile, so the kernel declares 4 KB of threadgroup memory in total.
// The smaller footprint is meant to raise concurrent-threadgroup occupancy
// per GPU core on Apple Silicon.
//
// Precision: each Q8_0 block carries an f16 scale, and the dequantized
// value int8 × scale is rounded to f16 here, so low-order bits can be lost
// and magnitudes above 65504 overflow to inf.  Token activations are
// likewise narrowed to f16.  simdgroup_multiply_accumulate keeps the
// running sums in `float`, so only the rounded inputs differ from mma32.
// NOTE(review): confirm activation magnitudes stay within f16 range for
// every model routed to this kernel.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups (256 threads) × 2 row-stacked
// 8×8 accumulators each.  Like mma32 it performs no row-bounds checks
// (rows % 32 == 0, cols % 32 == 0).  Primary win vs mma32 is occupancy at
// moderate prefill lengths where the GPU is wave-starved.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit (depends only on tgid), taken before any barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // One 4 KB allocation laid out as [w_tile: 32×32 half][t_tile: 32×32 half].
    // After the K loop it is reinterpreted as 1024 floats of output staging.
    threadgroup float tg_mem[(MMA32_ROW_TILE + MMA32_TOK_TILE) * 32 / 2];
    threadgroup half* w_tile = (threadgroup half*)tg_mem;
    threadgroup half* t_tile = w_tile + MMA32_ROW_TILE * 32;

    uint sg_tok_idx    = simd_id / 2;      // 0..3
    uint sg_row_half   = simd_id % 2;      // 0..1
    uint sg_tok_base   = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // Q8_0 block: 2-byte f16 scale + 32 int8

    // Per-thread load slot: 4 consecutive elements of one 32-wide tile row
    // (1024 elements / 256 threads); invariant across K blocks.
    uint ld_base = tid * 4;
    uint ld_row  = ld_base / 32;
    uint ld_k    = ld_base % 32;
    device const uchar* w_row = matrix + (row_base + ld_row) * row_bytes;
    uint ld_tok = tok_base + ld_row;
    bool ld_tok_valid = ld_tok < num_tokens;
    device const float* t_row = inputs + ld_tok * cols;  // dereferenced only if ld_tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Dequantize this thread's 4 weights to FP16.
        device const uchar* bp = w_row + blk * 34;
        float sc = float(*(device const half*)bp);
        device const int8_t* qs = (device const int8_t*)(bp + 2);
        for (uint ii = 0; ii < 4; ii++) {
            w_tile[ld_base + ii] = half(float(int(qs[ld_k + ii])) * sc);
        }

        // Copy this thread's 4 activations (f32 → f16 narrowing), zero past num_tokens.
        for (uint ii = 0; ii < 4; ii++) {
            t_tile[ld_base + ii] = ld_tok_valid
                ? half(t_row[blk * 32 + ld_k + ii])
                : half(0);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                t_tile + sg_tok_base * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_a,
                w_tile + sg_row_base_a * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_b,
                w_tile + sg_row_base_b * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Tiles must be fully consumed before being overwritten or reused as staging.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial last token tile: stage in the (now idle) tile memory.
        // 8 simdgroups × 2 accs × 64 floats = 1024 floats = all of tg_mem.
        threadgroup float* scratch = tg_mem + simd_id * 128;
        simdgroup_store(C_a, scratch, 8);
        simdgroup_store(C_b, scratch + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + m * 8;
                threadgroup const float* src_b = scratch + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1857
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel (same f16 rounding
// of dequantized weights and activations as matmul_q8_mma32_h; sums are
// accumulated in float).
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × a **2×2 grid** of 8×8 accumulators each.  Per simdgroup:
//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
// simd_id selects one 16×16 quadrant of the 32×32 output tile
// (simd_id / 2 → token half, simd_id % 2 → row half).
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's — one MMA per fragment load, versus two MMAs per three loads
// (1.5× fewer) in the 2-accumulator kernel.  The threadgroup is also half
// the size (128 threads), which may help occupancy on Apple GPUs when the
// concurrent-thread budget, rather than threadgroup memory, is the limit.
//
// Like mma32 it performs no row-bounds checks (rows % 32 == 0,
// cols % 32 == 0) and requires a 128-thread threadgroup.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],  // [M, cols]
    device float* outputs        [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit (depends only on tgid), taken before any barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles sharing one 4 KB allocation:
    // [w_tile: 32×32 half][t_tile: 32×32 half].  Reinterpreted after the K
    // loop as 1024 floats of output staging (4 simdgroups × 4 accs × 64).
    threadgroup float tg_mem[(MMA32_ROW_TILE + MMA32_TOK_TILE) * 32 / 2];
    threadgroup half* w_tile = (threadgroup half*)tg_mem;
    threadgroup half* t_tile = w_tile + MMA32_ROW_TILE * 32;

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;  // Q8_0 block: 2-byte f16 scale + 32 int8

    // Per-thread load slot: 8 consecutive elements (1024 / 128 threads) that
    // never straddle a 32-wide tile row; invariant across K blocks.
    uint ld_base = tid * 8;
    uint ld_row  = ld_base / 32;
    uint ld_k    = ld_base % 32;
    device const uchar* w_row = matrix + (row_base + ld_row) * row_bytes;
    uint ld_tok = tok_base + ld_row;
    bool ld_tok_valid = ld_tok < num_tokens;
    device const float* t_row = inputs + ld_tok * cols;  // dereferenced only if ld_tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Dequantize this thread's 8 weights to FP16.
        device const uchar* bp = w_row + blk * 34;
        float sc = float(*(device const half*)bp);
        device const int8_t* qs = (device const int8_t*)(bp + 2);
        for (uint ii = 0; ii < 8; ii++) {
            w_tile[ld_base + ii] = half(float(int(qs[ld_k + ii])) * sc);
        }

        // Copy this thread's 8 activations (f32 → f16), zero past num_tokens.
        for (uint ii = 0; ii < 8; ii++) {
            t_tile[ld_base + ii] = ld_tok_valid
                ? half(t_row[blk * 32 + ld_k + ii])
                : half(0);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(A_bot,
                t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                false);
            simdgroup_load(B_lo,
                w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_load(B_hi,
                w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                32,
                ulong2(0, 0),
                true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Tiles must be fully consumed before being overwritten or reused as staging.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles.  The fast path requires all 16 tokens of this
    // quadrant to be valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    if (out_tok_bot + 8 <= num_tokens) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else if (out_tok_top < num_tokens) {
        // Partial-token fallback via per-simdgroup staging in the idle tile
        // memory.  Quadrants lying entirely past num_tokens skip this work.
        threadgroup float* scratch = tg_mem + simd_id * 256;  // 4 accs × 64 floats
        simdgroup_store(C_00, scratch +   0, 8);
        simdgroup_store(C_01, scratch +  64, 8);
        simdgroup_store(C_10, scratch + 128, 8);
        simdgroup_store(C_11, scratch + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        if (tid % 32 == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch +   0 + m * 8;
                    threadgroup const float* src1 = scratch +  64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + 128 + m * 8;
                    threadgroup const float* src3 = scratch + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2012
2013// ── matmul_vec_q4_batch ────────────────────────────────────────────────
2014// Batched Q4_0 matrix-vector multiply for M input vectors.
2015// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
2016kernel void matmul_vec_q4_batch(
2017    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
2018    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
2019    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
2020    constant uint& num_tokens    [[buffer(3)]],  // M
2021    constant uint& rows          [[buffer(4)]],
2022    constant uint& cols          [[buffer(5)]],
2023    uint tgid [[threadgroup_position_in_grid]],
2024    uint tid [[thread_index_in_threadgroup]],
2025    uint simd_lane [[thread_index_in_simdgroup]],
2026    uint simd_id [[simdgroup_index_in_threadgroup]])
2027{
2028    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
2029    uint token = tgid / row_tgs;
2030    uint tg_in_token = tgid % row_tgs;
2031    if (token >= num_tokens) return;
2032
2033    threadgroup float vec_tile[VEC_TILE_SIZE];
2034    device const float* input = inputs + token * cols;
2035    for (uint i = tid; i < cols; i += 256) {
2036        vec_tile[i] = input[i];
2037    }
2038    threadgroup_barrier(mem_flags::mem_threadgroup);
2039
2040    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
2041    if (row_base >= rows) return;
2042
2043    uint blocks_per_row = cols / 32;
2044    uint row_bytes = blocks_per_row * 18;
2045
2046    device const uchar* r0 = matrix + row_base * row_bytes;
2047    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
2048    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
2049    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
2050
2051    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
2052
2053    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
2054        uint bb = blk * 18;
2055        uint vb = blk * 32;
2056
2057        float sc0 = float(*(device const half*)(r0 + bb));
2058        float sc1 = float(*(device const half*)(r1 + bb));
2059        float sc2 = float(*(device const half*)(r2 + bb));
2060        float sc3 = float(*(device const half*)(r3 + bb));
2061
2062        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
2063        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
2064        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
2065        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
2066
2067        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
2068        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
2069        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
2070        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
2071        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
2072        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
2073        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
2074        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
2075
2076        float bd0=0, bd1=0, bd2=0, bd3=0;
2077        uchar4 b;
2078
2079        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2080        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2081        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2082        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2083
2084        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2085        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2086        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2087        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2088
2089        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2090        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2091        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2092        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2093
2094        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2095        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2096        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2097        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2098
2099        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
2100    }
2101
2102    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
2103    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
2104
2105    device float* output = outputs + token * rows;
2106    if (simd_lane == 0) {
2107        if (row_base     < rows) output[row_base]     = sum0;
2108        if (row_base + 1 < rows) output[row_base + 1] = sum1;
2109        if (row_base + 2 < rows) output[row_base + 2] = sum2;
2110        if (row_base + 3 < rows) output[row_base + 3] = sum3;
2111    }
2112}
2113
2114// ── copy_kv_batch ─────────────────────────────────────────────────────
2115// Copy K or V from a strided batch QKV buffer to the KV cache.
2116// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
2117// dst layout: contiguous [max_seq, kv_dim] cache.
2118kernel void copy_kv_batch(
2119    device const float* src  [[buffer(0)]],  // batch QKV buffer
2120    device float* dst        [[buffer(1)]],  // KV cache
2121    constant uint& M         [[buffer(2)]],  // num tokens in batch
2122    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
2123    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
2124    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
2125    constant uint& src_offset [[buffer(6)]], // float offset within each src row
2126    uint id [[thread_position_in_grid]])
2127{
2128    uint total = M * kv_dim;
2129    if (id >= total) return;
2130    uint token = id / kv_dim;
2131    uint d = id % kv_dim;
2132    uint dst_off = (base_pos + token) * kv_dim + d;
2133    uint src_off = token * src_stride + src_offset + d;
2134    dst[dst_off] = src[src_off];
2135}
2136
2137// ── attention_batch ───────────────────────────────────────────────────
2138// Batched causal attention for prefill. Processes M tokens in one dispatch.
2139// Each threadgroup handles one (token, head) pair.
2140// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
2141// Causal masking: token i can only attend to positions 0..base_pos+i.
2142kernel void attention_batch(
2143    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
2144    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
2145    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
2146    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
2147    constant uint& M                 [[buffer(4)]],  // num tokens in batch
2148    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
2149    constant uint& num_heads         [[buffer(6)]],
2150    constant uint& num_kv_heads      [[buffer(7)]],
2151    constant uint& head_dim          [[buffer(8)]],
2152    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
2153    uint tgid [[threadgroup_position_in_grid]],
2154    uint tid [[thread_index_in_threadgroup]],
2155    uint simd_lane [[thread_index_in_simdgroup]],
2156    uint simd_id [[simdgroup_index_in_threadgroup]])
2157{
2158    // Grid: M * num_heads threadgroups
2159    uint token_idx = tgid / num_heads;
2160    uint head = tgid % num_heads;
2161    if (token_idx >= M) return;
2162
2163    uint kv_head = head / (num_heads / num_kv_heads);
2164    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx
2165
2166    // Q offset uses strided layout (from batch QKV buffer)
2167    uint q_off = token_idx * q_stride + head * head_dim;
2168    // Output is contiguous [M, num_heads * head_dim]
2169    uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2170
2171    // Shared memory for attention scores
2172    threadgroup float scores[2048];
2173
2174    // Step 1: Q * K^T with simdgroup reduction
2175    for (uint s = simd_id; s < seq_len; s += 8) {
2176        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
2177        float dot = 0.0;
2178        for (uint d = simd_lane; d < head_dim; d += 32) {
2179            dot += q_batch[q_off + d] * k_cache[k_off + d];
2180        }
2181        dot = simd_sum(dot);
2182        if (simd_lane == 0) {
2183            scores[s] = dot * fast::rsqrt(float(head_dim));
2184        }
2185    }
2186    threadgroup_barrier(mem_flags::mem_threadgroup);
2187
2188    // Step 2: Softmax (cooperative)
2189    float local_max = -INFINITY;
2190    for (uint s = tid; s < seq_len; s += 256) {
2191        local_max = max(local_max, scores[s]);
2192    }
2193    local_max = simd_max(local_max);
2194    threadgroup float shared_max[8];
2195    if (simd_lane == 0) shared_max[simd_id] = local_max;
2196    threadgroup_barrier(mem_flags::mem_threadgroup);
2197    if (tid == 0) {
2198        float m = shared_max[0];
2199        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
2200        shared_max[0] = m;
2201    }
2202    threadgroup_barrier(mem_flags::mem_threadgroup);
2203    float max_val = shared_max[0];
2204
2205    float local_sum = 0.0;
2206    for (uint s = tid; s < seq_len; s += 256) {
2207        scores[s] = fast::exp(scores[s] - max_val);
2208        local_sum += scores[s];
2209    }
2210    local_sum = simd_sum(local_sum);
2211    threadgroup float shared_sum[8];
2212    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
2213    threadgroup_barrier(mem_flags::mem_threadgroup);
2214    if (tid == 0) {
2215        float total = 0.0;
2216        for (uint i = 0; i < 8; i++) total += shared_sum[i];
2217        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
2218    }
2219    threadgroup_barrier(mem_flags::mem_threadgroup);
2220    float inv_sum = shared_sum[0];
2221    for (uint s = tid; s < seq_len; s += 256) {
2222        scores[s] *= inv_sum;
2223    }
2224    threadgroup_barrier(mem_flags::mem_threadgroup);
2225
2226    // Step 3: scores * V using float4 vectorized loads
2227    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
2228    // This is much better than the scalar version where only 64 of 256 threads are active.
2229    uint v_stride = num_kv_heads * head_dim;
2230    uint head_dim4 = head_dim / 4;
2231    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
2232        uint d = d4 * 4;
2233        float4 acc = float4(0.0);
2234        uint v_base = kv_head * head_dim + d;
2235        uint seq_len4 = seq_len & ~3u;
2236        for (uint s = 0; s < seq_len4; s += 4) {
2237            float sc0 = scores[s];
2238            float sc1 = scores[s + 1];
2239            float sc2 = scores[s + 2];
2240            float sc3 = scores[s + 3];
2241            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
2242                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
2243                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
2244                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
2245        }
2246        for (uint s = seq_len4; s < seq_len; s++) {
2247            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
2248        }
2249        *(device float4*)(output_batch + out_off + d) = acc;
2250    }
2251    // Handle remaining dimensions not divisible by 4 (scalar fallback)
2252    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
2253        float acc = 0.0;
2254        uint v_base = kv_head * head_dim + d;
2255        for (uint s = 0; s < seq_len; s++) {
2256            acc += scores[s] * v_cache[s * v_stride + v_base];
2257        }
2258        output_batch[out_off + d] = acc;
2259    }
2260}
2261
2262// ── rope_qk_batch ─────────────────────────────────────────────────────
2263// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2264// launch + memory barrier per layer. Both Q and K live in the same
2265// qkv_data buffer at different offsets within each token's row.
2266// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2267kernel void rope_qk_batch(
2268    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
2269    constant uint& M                 [[buffer(1)]],   // num tokens
2270    constant uint& base_pos          [[buffer(2)]],   // starting position
2271    constant uint& num_q_heads       [[buffer(3)]],
2272    constant uint& num_kv_heads      [[buffer(4)]],
2273    constant uint& head_dim          [[buffer(5)]],
2274    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
2275    constant float& theta            [[buffer(7)]],
2276    uint id [[thread_position_in_grid]])
2277{
2278    uint half_dim = head_dim / 2;
2279    uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2280    uint token = id / total_pairs;
2281    uint pair = id % total_pairs;
2282    if (token >= M) return;
2283
2284    uint pos = base_pos + token;
2285    uint q_pairs = num_q_heads * half_dim;
2286
2287    uint h, i, offset;
2288    if (pair < q_pairs) {
2289        // Q head
2290        h = pair / half_dim;
2291        i = pair % half_dim;
2292        offset = token * qkv_stride + h * head_dim + i * 2;
2293    } else {
2294        // K head
2295        uint kp = pair - q_pairs;
2296        h = kp / half_dim;
2297        i = kp % half_dim;
2298        uint k_start = num_q_heads * head_dim;
2299        offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2300    }
2301
2302    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2303    float angle = float(pos) * freq;
2304    float cos_val = cos(angle);
2305    float sin_val = sin(angle);
2306
2307    float x0 = qkv_data[offset];
2308    float x1 = qkv_data[offset + 1];
2309    qkv_data[offset]     = x0 * cos_val - x1 * sin_val;
2310    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2311}
2312
2313// ── copy_kv_both_batch ────────────────────────────────────────────────
2314// Fused K+V cache copy in a single dispatch: copies both K and V from
2315// the strided batch QKV buffer to their respective KV cache buffers.
2316// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2317kernel void copy_kv_both_batch(
2318    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
2319    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
2320    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
2321    constant uint& M           [[buffer(3)]],  // num tokens in batch
2322    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
2323    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
2324    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
2325    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
2326    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
2327    uint id [[thread_position_in_grid]])
2328{
2329    // Total elements = M * kv_dim * 2 (K + V)
2330    uint total_kv = M * kv_dim;
2331    if (id >= total_kv * 2) return;
2332
2333    uint is_v = id / total_kv;        // 0 = K, 1 = V
2334    uint local_id = id % total_kv;
2335    uint token = local_id / kv_dim;
2336    uint d = local_id % kv_dim;
2337
2338    uint dst_off = (base_pos + token) * kv_dim + d;
2339    uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2340
2341    if (is_v) {
2342        v_dst[dst_off] = src[src_off];
2343    } else {
2344        k_dst[dst_off] = src[src_off];
2345    }
2346}
2347"#
2348    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2349}
2350
2351// ---------------------------------------------------------------------------
2352// model.rs generation
2353// ---------------------------------------------------------------------------
2354
2355fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2356    let mut code = String::with_capacity(48 * 1024);
2357    emit_model_header(&mut code, config)?;
2358    emit_metal_model_struct(&mut code, config)?;
2359    emit_layer_buffers_struct(&mut code)?;
2360    emit_metal_model_impl(&mut code, config)?;
2361    emit_helper_functions(&mut code)?;
2362    Ok(code)
2363}
2364
2365fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2366    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2367    writeln!(
2368        code,
2369        "//! Model: {} ({} layers, hidden={})",
2370        config.architecture, config.num_layers, config.hidden_size
2371    )?;
2372    writeln!(code, "//!")?;
2373    writeln!(
2374        code,
2375        "//! Uses native Metal compute pipelines via the metal crate."
2376    )?;
2377    writeln!(code)?;
2378    writeln!(code, "#![allow(dead_code)]")?;
2379    writeln!(code)?;
2380    writeln!(code, "use metal::*;")?;
2381    writeln!(code, "#[allow(unused_imports)]")?;
2382    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2383    writeln!(code, "use std::mem;")?;
2384    writeln!(code)?;
2385
2386    // Model constants
2387    writeln!(
2388        code,
2389        "// ── Model constants ──────────────────────────────────"
2390    )?;
2391    writeln!(
2392        code,
2393        "pub const HIDDEN_SIZE: usize = {};",
2394        config.hidden_size
2395    )?;
2396    writeln!(
2397        code,
2398        "pub const INTERMEDIATE_SIZE: usize = {};",
2399        config.intermediate_size
2400    )?;
2401    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2402    writeln!(
2403        code,
2404        "pub const NUM_HEADS: usize = {};",
2405        config.num_attention_heads
2406    )?;
2407    writeln!(
2408        code,
2409        "pub const NUM_KV_HEADS: usize = {};",
2410        config.num_kv_heads
2411    )?;
2412    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2413    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2414    let effective_seq_len = config.max_seq_len.min(4096);
2415    writeln!(
2416        code,
2417        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
2418        effective_seq_len, config.max_seq_len
2419    )?;
2420    writeln!(
2421        code,
2422        "pub const RMS_NORM_EPS: f32 = {:e};",
2423        config.rms_norm_eps
2424    )?;
2425    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
2426    writeln!(
2427        code,
2428        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
2429    )?;
2430    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
2431    writeln!(code)?;
2432
2433    Ok(())
2434}
2435
2436fn emit_metal_model_struct(
2437    code: &mut String,
2438    _config: &ModelConfig,
2439) -> Result<(), MetalCodegenError> {
2440    writeln!(
2441        code,
2442        "// ── MetalModel ──────────────────────────────────────────"
2443    )?;
2444    writeln!(code)?;
2445    writeln!(
2446        code,
2447        "/// Metal-accelerated transformer model for Apple Silicon."
2448    )?;
2449    writeln!(code, "///")?;
2450    writeln!(
2451        code,
2452        "/// Uses unified memory for zero-copy weight access and native Metal"
2453    )?;
2454    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
2455    writeln!(code, "pub struct MetalModel {{")?;
2456    writeln!(code, "    device: Device,")?;
2457    writeln!(code, "    queue: CommandQueue,")?;
2458    writeln!(code)?;
2459    writeln!(code, "    // ── Compute pipelines ──")?;
2460    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
2461    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
2462    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
2463    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
2464    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
2465    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
2466    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
2467    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
2468    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
2469    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
2470    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
2471    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
2472    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
2473    writeln!(code)?;
2474    writeln!(code, "    // ── Batched prefill pipelines ──")?;
2475    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
2476    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
2477    writeln!(
2478        code,
2479        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
2480    )?;
2481    writeln!(
2482        code,
2483        "    matmul_q8_mma_pipeline: ComputePipelineState,"
2484    )?;
2485    writeln!(
2486        code,
2487        "    matmul_q8_mma32_pipeline: ComputePipelineState,"
2488    )?;
2489    writeln!(
2490        code,
2491        "    matmul_q8_mma32_h_pipeline: ComputePipelineState,"
2492    )?;
2493    writeln!(
2494        code,
2495        "    matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
2496    )?;
2497    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
2498    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
2499    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
2500    writeln!(
2501        code,
2502        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
2503    )?;
2504    writeln!(
2505        code,
2506        "    add_inplace_batch_pipeline: ComputePipelineState,"
2507    )?;
2508    writeln!(
2509        code,
2510        "    copy_embedding_batch_pipeline: ComputePipelineState,"
2511    )?;
2512    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
2513    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
2514    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
2515    writeln!(
2516        code,
2517        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
2518    )?;
2519    writeln!(code)?;
2520    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
2521    writeln!(
2522        code,
2523        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
2524    )?;
2525    writeln!(code, "    embed_buf: Buffer,")?;
2526    writeln!(code)?;
2527    writeln!(code, "    /// Per-layer weight buffers")?;
2528    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
2529    writeln!(code)?;
2530    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
2531    writeln!(code, "    norm_buf: Buffer,")?;
2532    writeln!(code)?;
2533    writeln!(
2534        code,
2535        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
2536    )?;
2537    writeln!(code, "    lm_head_buf: Buffer,")?;
2538    writeln!(code)?;
2539    writeln!(
2540        code,
2541        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
2542    )?;
2543    writeln!(code, "    hidden_buf: Buffer,")?;
2544    writeln!(code, "    residual_buf: Buffer,")?;
2545    writeln!(code, "    normed_buf: Buffer,")?;
2546    writeln!(
2547        code,
2548        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
2549    )?;
2550    writeln!(code, "    qkv_buf: Buffer,")?;
2551    writeln!(code, "    attn_out_buf: Buffer,")?;
2552    writeln!(code, "    attn_proj_buf: Buffer,")?;
2553    writeln!(
2554        code,
2555        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
2556    )?;
2557    writeln!(code, "    gate_up_buf: Buffer,")?;
2558    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
2559    writeln!(code, "    ffn_out_buf: Buffer,")?;
2560    writeln!(code, "    add_tmp_buf: Buffer,")?;
2561    writeln!(code, "    logits_buf: Buffer,")?;
2562    writeln!(code)?;
2563    writeln!(code, "    // ── Batched prefill working buffers ──")?;
2564    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
2565    writeln!(code, "    batch_hidden_buf: Buffer,")?;
2566    writeln!(
2567        code,
2568        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
2569    )?;
2570    writeln!(code, "    batch_residual_buf: Buffer,")?;
2571    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2572    writeln!(code, "    batch_qkv_buf: Buffer,")?;
2573    writeln!(
2574        code,
2575        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2576    )?;
2577    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
2578    writeln!(
2579        code,
2580        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2581    )?;
2582    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
2583    writeln!(
2584        code,
2585        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2586    )?;
2587    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
2588    writeln!(
2589        code,
2590        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2591    )?;
2592    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
2593    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2594    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
2595    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
2596    writeln!(code, "    batch_tokens_buf: Buffer,")?;
2597    writeln!(code, "    /// Positions buffer for batch RoPE")?;
2598    writeln!(code, "    batch_positions_buf: Buffer,")?;
2599    writeln!(code)?;
2600    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
2601    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
2602    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
2603    writeln!(code)?;
2604    writeln!(code, "    // ── Inference state ──")?;
2605    writeln!(code, "    pos: usize,")?;
2606    writeln!(code)?;
2607    writeln!(
2608        code,
2609        "    /// Previous command buffer for double-buffered prefill."
2610    )?;
2611    writeln!(
2612        code,
2613        "    /// While the GPU executes token N, the CPU can encode token N+1."
2614    )?;
2615    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
2616    writeln!(code, "}}")?;
2617    writeln!(code)?;
2618
2619    Ok(())
2620}
2621
2622fn emit_layer_buffers_struct(code: &mut String) -> Result<(), MetalCodegenError> {
2623    writeln!(
2624        code,
2625        "/// Per-layer weight buffers for attention and FFN projections."
2626    )?;
2627    writeln!(code, "struct LayerBuffers {{")?;
2628    writeln!(code, "    attn_norm: Buffer,")?;
2629    writeln!(
2630        code,
2631        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2632    )?;
2633    writeln!(code, "    qkv_weight: Buffer,")?;
2634    writeln!(code, "    o_weight: Buffer,")?;
2635    writeln!(code, "    ffn_norm: Buffer,")?;
2636    writeln!(
2637        code,
2638        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2639    )?;
2640    writeln!(code, "    gate_up_weight: Buffer,")?;
2641    writeln!(code, "    down_weight: Buffer,")?;
2642    writeln!(code, "}}")?;
2643    writeln!(code)?;
2644
2645    Ok(())
2646}
2647
2648fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2649    let hidden = config.hidden_size;
2650    let intermediate = config.intermediate_size;
2651    let _num_layers = config.num_layers;
2652    let num_heads = config.num_attention_heads;
2653    let num_kv_heads = config.num_kv_heads;
2654    let head_dim = config.head_dim;
2655    let vocab = config.vocab_size;
2656    let effective_seq_len = config.max_seq_len.min(4096);
2657    let is_q8 = config.dtype == DType::Q8_0;
2658    let is_q4 = config.dtype == DType::Q4_0;
2659    let kv_dim = num_kv_heads * head_dim;
2660
2661    writeln!(code, "impl MetalModel {{")?;
2662
2663    // ── new() ──
2664    writeln!(
2665        code,
2666        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
2667    )?;
2668    writeln!(code, "    ///")?;
2669    writeln!(
2670        code,
2671        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
2672    )?;
2673    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
2674    writeln!(
2675        code,
2676        "        let device = Device::system_default().expect(\"no Metal device found\");"
2677    )?;
2678    writeln!(code, "        let queue = device.new_command_queue();")?;
2679    writeln!(code)?;
2680
2681    // Compile shaders
2682    writeln!(
2683        code,
2684        "        // Compile Metal shaders from embedded source"
2685    )?;
2686    writeln!(
2687        code,
2688        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
2689    )?;
2690    writeln!(
2691        code,
2692        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
2693    )?;
2694    writeln!(
2695        code,
2696        "            .expect(\"failed to compile Metal shaders\");"
2697    )?;
2698    writeln!(code)?;
2699
2700    // Create compute pipelines
2701    writeln!(code, "        // Create compute pipelines")?;
2702    for (var, fn_name) in [
2703        ("matmul_pipeline", "matmul_vec"),
2704        ("matmul_q8_pipeline", "matmul_vec_q8"),
2705        ("matmul_q4_pipeline", "matmul_vec_q4"),
2706        ("rms_norm_pipeline", "rms_norm"),
2707        ("rope_pipeline", "rope"),
2708        ("softmax_pipeline", "softmax"),
2709        ("silu_mul_pipeline", "silu_mul"),
2710        ("silu_mul_fused_pipeline", "silu_mul_fused"),
2711        ("add_pipeline", "elementwise_add"),
2712        ("attention_pipeline", "attention"),
2713        ("add_inplace_pipeline", "add_inplace"),
2714        ("copy_pipeline", "copy_buffer"),
2715        ("copy_offset_pipeline", "copy_offset"),
2716        ("matmul_batch_pipeline", "matmul_vec_batch"),
2717        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
2718        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
2719        ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
2720        ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
2721        ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
2722        ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
2723        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
2724        ("rms_norm_batch_pipeline", "rms_norm_batch"),
2725        ("rope_batch_pipeline", "rope_batch"),
2726        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
2727        ("add_inplace_batch_pipeline", "add_inplace_batch"),
2728        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
2729        ("attention_batch_pipeline", "attention_batch"),
2730        ("copy_kv_batch_pipeline", "copy_kv_batch"),
2731        ("rope_qk_batch_pipeline", "rope_qk_batch"),
2732        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
2733    ] {
2734        writeln!(
2735            code,
2736            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
2737        )?;
2738    }
2739    writeln!(code)?;
2740
2741    // Weight loading
2742    writeln!(
2743        code,
2744        "        // Load weights into Metal shared-memory buffers"
2745    )?;
2746    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
2747    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
2748    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
2749    writeln!(code)?;
2750    writeln!(
2751        code,
2752        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
2753    )?;
2754    writeln!(code)?;
2755    writeln!(
2756        code,
2757        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
2758    )?;
2759    writeln!(
2760        code,
2761        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
2762    )?;
2763    writeln!(code, "            let byte_len = n * f32_size;")?;
2764    writeln!(code, "            let cur = cursor.get();")?;
2765    writeln!(
2766        code,
2767        "            let data = &weights[cur..cur + byte_len];"
2768    )?;
2769    writeln!(code, "            cursor.set(cur + byte_len);")?;
2770    writeln!(code, "            device.new_buffer_with_data(")?;
2771    writeln!(code, "                data.as_ptr() as *const _,")?;
2772    writeln!(code, "                byte_len as u64,")?;
2773    writeln!(
2774        code,
2775        "                MTLResourceOptions::StorageModeShared,"
2776    )?;
2777    writeln!(code, "            )")?;
2778    writeln!(code, "        }};")?;
2779    writeln!(code)?;
2780
2781    if is_q8 {
2782        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
2783        // We load them directly into Metal buffers without dequantizing,
2784        // and use the matmul_vec_q8 shader that operates on quantized data.
2785        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
2786        writeln!(
2787            code,
2788            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
2789        )?;
2790        writeln!(
2791            code,
2792            "        // as raw bytes into a Metal buffer (no dequantization)."
2793        )?;
2794        writeln!(
2795            code,
2796            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
2797        )?;
2798        writeln!(
2799            code,
2800            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
2801        )?;
2802        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2803        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
2804        writeln!(code, "            let total_raw = rows * row_bytes;")?;
2805        writeln!(code, "            let cur = cursor.get();")?;
2806        writeln!(
2807            code,
2808            "            let data = &weights[cur..cur + total_raw];"
2809        )?;
2810        writeln!(code, "            cursor.set(cur + total_raw);")?;
2811        writeln!(code, "            device.new_buffer_with_data(")?;
2812        writeln!(code, "                data.as_ptr() as *const _,")?;
2813        writeln!(code, "                total_raw as u64,")?;
2814        writeln!(
2815            code,
2816            "                MTLResourceOptions::StorageModeShared,"
2817        )?;
2818        writeln!(code, "            )")?;
2819        writeln!(code, "        }};")?;
2820        writeln!(code)?;
2821        writeln!(
2822            code,
2823            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
2824        )?;
2825        writeln!(
2826            code,
2827            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
2828        )?;
2829        writeln!(
2830            code,
2831            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
2832        )?;
2833        writeln!(
2834            code,
2835            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
2836        )?;
2837        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2838        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
2839        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
2840        writeln!(code, "            let cur = cursor.get();")?;
2841        writeln!(
2842            code,
2843            "            let data = &weights[cur..cur + total_raw];"
2844        )?;
2845        writeln!(code, "            cursor.set(cur + total_raw);")?;
2846        writeln!(code, "            device.new_buffer_with_data(")?;
2847        writeln!(code, "                data.as_ptr() as *const _,")?;
2848        writeln!(code, "                total_raw as u64,")?;
2849        writeln!(
2850            code,
2851            "                MTLResourceOptions::StorageModeShared,"
2852        )?;
2853        writeln!(code, "            )")?;
2854        writeln!(code, "        }};")?;
2855        writeln!(code)?;
2856    }
2857
2858    if is_q4 {
2859        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
2860        // We load them directly into Metal buffers without dequantizing,
2861        // and use the matmul_vec_q4 shader that operates on quantized data.
2862        // This quarters GPU memory usage vs f32 dequantization.
2863        writeln!(
2864            code,
2865            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
2866        )?;
2867        writeln!(
2868            code,
2869            "        // as raw bytes into a Metal buffer (no dequantization)."
2870        )?;
2871        writeln!(
2872            code,
2873            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
2874        )?;
2875        writeln!(
2876            code,
2877            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
2878        )?;
2879        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2880        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
2881        writeln!(code, "            let total_raw = rows * row_bytes;")?;
2882        writeln!(code, "            let cur = cursor.get();")?;
2883        writeln!(
2884            code,
2885            "            let data = &weights[cur..cur + total_raw];"
2886        )?;
2887        writeln!(code, "            cursor.set(cur + total_raw);")?;
2888        writeln!(code, "            device.new_buffer_with_data(")?;
2889        writeln!(code, "                data.as_ptr() as *const _,")?;
2890        writeln!(code, "                total_raw as u64,")?;
2891        writeln!(
2892            code,
2893            "                MTLResourceOptions::StorageModeShared,"
2894        )?;
2895        writeln!(code, "            )")?;
2896        writeln!(code, "        }};")?;
2897        writeln!(code)?;
2898        writeln!(
2899            code,
2900            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
2901        )?;
2902        writeln!(
2903            code,
2904            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
2905        )?;
2906        writeln!(
2907            code,
2908            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
2909        )?;
2910        writeln!(
2911            code,
2912            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
2913        )?;
2914        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2915        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
2916        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
2917        writeln!(code, "            let cur = cursor.get();")?;
2918        writeln!(
2919            code,
2920            "            let data = &weights[cur..cur + total_raw];"
2921        )?;
2922        writeln!(code, "            cursor.set(cur + total_raw);")?;
2923        writeln!(code, "            device.new_buffer_with_data(")?;
2924        writeln!(code, "                data.as_ptr() as *const _,")?;
2925        writeln!(code, "                total_raw as u64,")?;
2926        writeln!(
2927            code,
2928            "                MTLResourceOptions::StorageModeShared,"
2929        )?;
2930        writeln!(code, "            )")?;
2931        writeln!(code, "        }};")?;
2932        writeln!(code)?;
2933    }
2934
2935    writeln!(
2936        code,
2937        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
2938    )?;
2939    writeln!(code)?;
2940
2941    // Per-layer weights
2942    writeln!(
2943        code,
2944        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
2945    )?;
2946    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
2947
2948    // attn_norm is always f32
2949    writeln!(
2950        code,
2951        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
2952    )?;
2953
2954    let qkv_rows = hidden + 2 * kv_dim;
2955    if is_q8 {
2956        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
2957        writeln!(
2958            code,
2959            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
2960        )?;
2961        writeln!(
2962            code,
2963            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
2964        )?;
2965    } else if is_q4 {
2966        // Fused Q+K+V weight: read all three consecutive Q4_0 matrices as one buffer
2967        writeln!(
2968            code,
2969            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
2970        )?;
2971        writeln!(
2972            code,
2973            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
2974        )?;
2975    } else {
2976        // Fused Q+K+V weight: read all three as a single contiguous f32 buffer
2977        writeln!(
2978            code,
2979            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
2980        )?;
2981        writeln!(
2982            code,
2983            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
2984        )?;
2985    }
2986
2987    // ffn_norm is always f32
2988    writeln!(
2989        code,
2990        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
2991    )?;
2992
2993    let gate_up_rows = 2 * intermediate;
2994    if is_q8 {
2995        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
2996        writeln!(
2997            code,
2998            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
2999        )?;
3000        writeln!(
3001            code,
3002            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3003        )?;
3004    } else if is_q4 {
3005        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
3006        writeln!(
3007            code,
3008            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3009        )?;
3010        writeln!(
3011            code,
3012            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3013        )?;
3014    } else {
3015        // Fused gate+up weight: read both as a single contiguous f32 buffer
3016        writeln!(
3017            code,
3018            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3019        )?;
3020        writeln!(
3021            code,
3022            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3023        )?;
3024    }
3025
3026    writeln!(code, "            layers.push(LayerBuffers {{")?;
3027    writeln!(code, "                attn_norm,")?;
3028    writeln!(code, "                qkv_weight,")?;
3029    writeln!(code, "                o_weight,")?;
3030    writeln!(code, "                ffn_norm,")?;
3031    writeln!(code, "                gate_up_weight,")?;
3032    writeln!(code, "                down_weight,")?;
3033    writeln!(code, "            }});")?;
3034    writeln!(code, "        }}")?;
3035    writeln!(code)?;
3036
3037    // final_norm is always f32
3038    writeln!(
3039        code,
3040        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
3041    )?;
3042    writeln!(code)?;
3043
3044    // lm_head
3045    if is_q8 {
3046        writeln!(
3047            code,
3048            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3049        )?;
3050    } else if is_q4 {
3051        writeln!(
3052            code,
3053            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3054        )?;
3055    } else {
3056        writeln!(
3057            code,
3058            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3059        )?;
3060    }
3061    writeln!(code)?;
3062
3063    // Working buffers
3064    let hidden_bytes = hidden * 4;
3065    let _kv_dim_bytes = kv_dim * 4;
3066    let intermediate_bytes = intermediate * 4;
3067    let vocab_bytes = vocab * 4;
3068    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3069
3070    writeln!(
3071        code,
3072        "        // Allocate working buffers (shared memory for zero-copy)"
3073    )?;
3074    writeln!(
3075        code,
3076        "        let opts = MTLResourceOptions::StorageModeShared;"
3077    )?;
3078    writeln!(
3079        code,
3080        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3081    )?;
3082    writeln!(
3083        code,
3084        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3085    )?;
3086    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3087    writeln!(
3088        code,
3089        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3090    )?;
3091    writeln!(
3092        code,
3093        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3094    )?;
3095    writeln!(
3096        code,
3097        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3098    )?;
3099    writeln!(
3100        code,
3101        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3102    )?;
3103    writeln!(
3104        code,
3105        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3106    )?;
3107    let gate_up_buf_bytes = 2 * intermediate * 4;
3108    writeln!(
3109        code,
3110        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3111    )?;
3112    writeln!(
3113        code,
3114        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3115    )?;
3116    writeln!(
3117        code,
3118        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3119    )?;
3120    writeln!(
3121        code,
3122        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3123    )?;
3124    writeln!(
3125        code,
3126        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3127    )?;
3128    writeln!(
3129        code,
3130        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3131    )?;
3132    writeln!(code)?;
3133
3134    // Batch prefill working buffers
3135    let batch_hidden_bytes = hidden * 4; // per-token
3136    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3137    let batch_gate_up_bytes = 2 * intermediate * 4;
3138    let batch_intermediate_bytes = intermediate * 4;
3139    writeln!(
3140        code,
3141        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3142    )?;
3143    writeln!(
3144        code,
3145        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3146    )?;
3147    writeln!(
3148        code,
3149        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3150    )?;
3151    writeln!(
3152        code,
3153        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3154    )?;
3155    writeln!(
3156        code,
3157        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3158    )?;
3159    writeln!(
3160        code,
3161        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3162    )?;
3163    writeln!(
3164        code,
3165        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3166    )?;
3167    writeln!(
3168        code,
3169        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3170    )?;
3171    writeln!(
3172        code,
3173        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3174    )?;
3175    writeln!(
3176        code,
3177        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3178    )?;
3179    writeln!(
3180        code,
3181        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3182    )?;
3183    writeln!(code)?;
3184
3185    // KV cache buffers
3186    writeln!(code, "        // KV cache buffers (per-layer)")?;
3187    writeln!(
3188        code,
3189        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3190    )?;
3191    writeln!(
3192        code,
3193        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3194    )?;
3195    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
3196    writeln!(
3197        code,
3198        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3199    )?;
3200    writeln!(
3201        code,
3202        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3203    )?;
3204    writeln!(code, "        }}")?;
3205    writeln!(code)?;
3206
3207    writeln!(code, "        Self {{")?;
3208    writeln!(code, "            device,")?;
3209    writeln!(code, "            queue,")?;
3210    writeln!(code, "            matmul_pipeline,")?;
3211    writeln!(code, "            matmul_q8_pipeline,")?;
3212    writeln!(code, "            matmul_q4_pipeline,")?;
3213    writeln!(code, "            rms_norm_pipeline,")?;
3214    writeln!(code, "            rope_pipeline,")?;
3215    writeln!(code, "            softmax_pipeline,")?;
3216    writeln!(code, "            silu_mul_pipeline,")?;
3217    writeln!(code, "            silu_mul_fused_pipeline,")?;
3218    writeln!(code, "            add_pipeline,")?;
3219    writeln!(code, "            attention_pipeline,")?;
3220    writeln!(code, "            add_inplace_pipeline,")?;
3221    writeln!(code, "            copy_pipeline,")?;
3222    writeln!(code, "            copy_offset_pipeline,")?;
3223    writeln!(code, "            matmul_batch_pipeline,")?;
3224    writeln!(code, "            matmul_q8_batch_pipeline,")?;
3225    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
3226    writeln!(code, "            matmul_q8_mma_pipeline,")?;
3227    writeln!(code, "            matmul_q8_mma32_pipeline,")?;
3228    writeln!(code, "            matmul_q8_mma32_h_pipeline,")?;
3229    writeln!(code, "            matmul_q8_mma32_h4_pipeline,")?;
3230    writeln!(code, "            matmul_q4_batch_pipeline,")?;
3231    writeln!(code, "            rms_norm_batch_pipeline,")?;
3232    writeln!(code, "            rope_batch_pipeline,")?;
3233    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
3234    writeln!(code, "            add_inplace_batch_pipeline,")?;
3235    writeln!(code, "            copy_embedding_batch_pipeline,")?;
3236    writeln!(code, "            attention_batch_pipeline,")?;
3237    writeln!(code, "            copy_kv_batch_pipeline,")?;
3238    writeln!(code, "            rope_qk_batch_pipeline,")?;
3239    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
3240    writeln!(code, "            embed_buf,")?;
3241    writeln!(code, "            layers,")?;
3242    writeln!(code, "            norm_buf,")?;
3243    writeln!(code, "            lm_head_buf,")?;
3244    writeln!(code, "            hidden_buf,")?;
3245    writeln!(code, "            residual_buf,")?;
3246    writeln!(code, "            normed_buf,")?;
3247    writeln!(code, "            qkv_buf,")?;
3248    writeln!(code, "            attn_out_buf,")?;
3249    writeln!(code, "            attn_proj_buf,")?;
3250    writeln!(code, "            gate_up_buf,")?;
3251    writeln!(code, "            ffn_hidden_buf,")?;
3252    writeln!(code, "            ffn_out_buf,")?;
3253    writeln!(code, "            add_tmp_buf,")?;
3254    writeln!(code, "            logits_buf,")?;
3255    writeln!(code, "            batch_hidden_buf,")?;
3256    writeln!(code, "            batch_residual_buf,")?;
3257    writeln!(code, "            batch_qkv_buf,")?;
3258    writeln!(code, "            batch_attn_out_buf,")?;
3259    writeln!(code, "            batch_attn_proj_buf,")?;
3260    writeln!(code, "            batch_gate_up_buf,")?;
3261    writeln!(code, "            batch_ffn_hidden_buf,")?;
3262    writeln!(code, "            batch_ffn_out_buf,")?;
3263    writeln!(code, "            batch_tokens_buf,")?;
3264    writeln!(code, "            batch_positions_buf,")?;
3265    writeln!(code, "            k_cache,")?;
3266    writeln!(code, "            v_cache,")?;
3267    writeln!(code, "            pos: 0,")?;
3268    writeln!(code, "            prev_cmd: None,")?;
3269    writeln!(code, "        }}")?;
3270    writeln!(code, "    }}")?;
3271    writeln!(code)?;
3272
3273    // ── forward() ──
3274    writeln!(
3275        code,
3276        "    /// Run the forward pass for a single token at the current position."
3277    )?;
3278    writeln!(code, "    ///")?;
3279    writeln!(
3280        code,
3281        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
3282    )?;
3283    writeln!(code, "    ///")?;
3284    writeln!(
3285        code,
3286        "    /// All GPU operations are encoded into a single command buffer and"
3287    )?;
3288    writeln!(
3289        code,
3290        "    /// committed once at the end, avoiding per-operation synchronization."
3291    )?;
3292    writeln!(
3293        code,
3294        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3295    )?;
3296    writeln!(
3297        code,
3298        "        // Wait for any pending prefill command buffer"
3299    )?;
3300    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3301    writeln!(code, "            prev.wait_until_completed();")?;
3302    writeln!(code, "        }}")?;
3303    writeln!(code)?;
3304    writeln!(code, "        let pos = self.pos;")?;
3305    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3306    writeln!(code)?;
3307
3308    // Single compute encoder for the entire forward pass — no blit encoder
3309    // transitions. Copy operations use compute copy kernels instead of blits.
3310    let matmul_fn = if is_q8 {
3311        "dispatch_matmul_q8"
3312    } else if is_q4 {
3313        "dispatch_matmul_q4"
3314    } else {
3315        "dispatch_matmul"
3316    };
3317
3318    writeln!(
3319        code,
3320        "        // Single compute encoder for the entire forward pass (no blit transitions)"
3321    )?;
3322    writeln!(code, "        {{")?;
3323    writeln!(
3324        code,
3325        "            let enc = cmd.new_compute_command_encoder();"
3326    )?;
3327    writeln!(code)?;
3328
3329    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
3330    writeln!(
3331        code,
3332        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3333    )?;
3334    writeln!(
3335        code,
3336        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3337    )?;
3338    writeln!(
3339        code,
3340        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3341    )?;
3342    writeln!(
3343        code,
3344        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3345        hidden * 4,
3346    )?;
3347    writeln!(code, "            unsafe {{")?;
3348    writeln!(
3349        code,
3350        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
3351    )?;
3352    writeln!(
3353        code,
3354        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3355    )?;
3356    writeln!(
3357        code,
3358        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
3359    )?;
3360    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
3361    writeln!(
3362        code,
3363        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3364    )?;
3365    writeln!(code, "                    hidden_ptr,")?;
3366    writeln!(code, "                    HIDDEN_SIZE,")?;
3367    writeln!(code, "                );")?;
3368    writeln!(
3369        code,
3370        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3371    )?;
3372    writeln!(code, "            }}")?;
3373    writeln!(code)?;
3374
3375    // 2. Transformer layers
3376    writeln!(code, "            // 2. Transformer layers")?;
3377    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3378    writeln!(code)?;
3379    let q_byte_offset = 0usize;
3380    let k_byte_offset = hidden * 4;
3381    let v_byte_offset = (hidden + kv_dim) * 4;
3382
3383    writeln!(
3384        code,
3385        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
3386    )?;
3387    writeln!(
3388        code,
3389        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3390    )?;
3391    writeln!(
3392        code,
3393        "                // Fused Q+K+V matmul: single dispatch for all three projections"
3394    )?;
3395    writeln!(
3396        code,
3397        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3398    )?;
3399    writeln!(
3400        code,
3401        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
3402    )?;
3403    writeln!(
3404        code,
3405        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3406    )?;
3407    writeln!(
3408        code,
3409        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3410    )?;
3411    writeln!(code)?;
3412    writeln!(
3413        code,
3414        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
3415    )?;
3416    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3417    writeln!(
3418        code,
3419        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3420    )?;
3421    writeln!(
3422        code,
3423        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3424    )?;
3425    writeln!(code)?;
3426    writeln!(
3427        code,
3428        "                // Attention using Q from qkv_buf (offset 0)"
3429    )?;
3430    writeln!(
3431        code,
3432        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3433    )?;
3434    writeln!(
3435        code,
3436        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3437    )?;
3438    writeln!(
3439        code,
3440        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3441    )?;
3442    writeln!(
3443        code,
3444        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
3445    )?;
3446    writeln!(
3447        code,
3448        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3449    )?;
3450    writeln!(
3451        code,
3452        "                // Fused gate+up matmul: single dispatch for both projections"
3453    )?;
3454    writeln!(
3455        code,
3456        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3457    )?;
3458    writeln!(
3459        code,
3460        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3461    )?;
3462    writeln!(
3463        code,
3464        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3465    )?;
3466    writeln!(
3467        code,
3468        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3469    )?;
3470    writeln!(code, "            }}")?;
3471    writeln!(code)?;
3472
3473    // 3. Final RMS norm + logits
3474    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
3475    writeln!(
3476        code,
3477        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3478    )?;
3479    writeln!(
3480        code,
3481        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3482    )?;
3483    writeln!(code)?;
3484    writeln!(code, "            enc.end_encoding();")?;
3485    writeln!(code, "        }}")?;
3486    writeln!(code)?;
3487
3488    // 5. Single commit + wait, then read back logits
3489    writeln!(
3490        code,
3491        "        // 5. Commit all GPU work and wait for completion"
3492    )?;
3493    writeln!(code, "        cmd.commit();")?;
3494    writeln!(code, "        cmd.wait_until_completed();")?;
3495    writeln!(code)?;
3496    writeln!(code, "        // 6. Read back logits from GPU")?;
3497    writeln!(code, "        let logits = unsafe {{")?;
3498    writeln!(
3499        code,
3500        "            let ptr = self.logits_buf.contents() as *const f32;"
3501    )?;
3502    writeln!(
3503        code,
3504        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3505    )?;
3506    writeln!(code, "        }};")?;
3507    writeln!(code)?;
3508    writeln!(code, "        self.pos += 1;")?;
3509    writeln!(code, "        logits")?;
3510    writeln!(code, "    }}")?;
3511    writeln!(code)?;
3512
3513    // ── forward_profile: instrumented forward with per-operation timing ──
3514    writeln!(
3515        code,
3516        "    /// Profiling forward pass that prints per-stage GPU timing."
3517    )?;
3518    writeln!(code, "    ///")?;
3519    writeln!(
3520        code,
3521        "    /// Each stage is committed and waited on separately so that GPU timestamps"
3522    )?;
3523    writeln!(
3524        code,
3525        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
3526    )?;
3527    writeln!(
3528        code,
3529        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
3530    )?;
3531    writeln!(
3532        code,
3533        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
3534    )?;
3535    writeln!(code, "        use std::time::Instant;")?;
3536    writeln!(code)?;
3537    writeln!(
3538        code,
3539        "        // Wait for any pending prefill command buffer"
3540    )?;
3541    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3542    writeln!(code, "            prev.wait_until_completed();")?;
3543    writeln!(code, "        }}")?;
3544    writeln!(code)?;
3545    writeln!(code, "        let pos = self.pos;")?;
3546    writeln!(code)?;
3547
3548    // Stage: embedding (CPU, no GPU)
3549    writeln!(
3550        code,
3551        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
3552    )?;
3553    writeln!(code, "        let t_embed = Instant::now();")?;
3554    writeln!(code, "        unsafe {{")?;
3555    writeln!(
3556        code,
3557        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
3558    )?;
3559    writeln!(
3560        code,
3561        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3562    )?;
3563    writeln!(
3564        code,
3565        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
3566    )?;
3567    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
3568    writeln!(
3569        code,
3570        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3571    )?;
3572    writeln!(code, "                hidden_ptr,")?;
3573    writeln!(code, "                HIDDEN_SIZE,")?;
3574    writeln!(code, "            );")?;
3575    writeln!(
3576        code,
3577        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3578    )?;
3579    writeln!(code, "        }}")?;
3580    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
3581    writeln!(code)?;
3582
3583    // Stage: Transformer layers (all together on GPU)
3584    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
3585    writeln!(code, "        let t_layers = Instant::now();")?;
3586    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3587    writeln!(code, "        {{")?;
3588    writeln!(
3589        code,
3590        "            let enc = cmd.new_compute_command_encoder();"
3591    )?;
3592    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3593    writeln!(
3594        code,
3595        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3596    )?;
3597    writeln!(
3598        code,
3599        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3600    )?;
3601    writeln!(
3602        code,
3603        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3604    )?;
3605    writeln!(
3606        code,
3607        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3608    )?;
3609    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3610    writeln!(
3611        code,
3612        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3613    )?;
3614    writeln!(
3615        code,
3616        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3617    )?;
3618    writeln!(
3619        code,
3620        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3621    )?;
3622    writeln!(
3623        code,
3624        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3625    )?;
3626    writeln!(
3627        code,
3628        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3629    )?;
3630    writeln!(
3631        code,
3632        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3633    )?;
3634    writeln!(
3635        code,
3636        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3637    )?;
3638    writeln!(
3639        code,
3640        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3641    )?;
3642    writeln!(
3643        code,
3644        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3645    )?;
3646    writeln!(
3647        code,
3648        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3649    )?;
3650    writeln!(code, "            }}")?;
3651    writeln!(code, "            enc.end_encoding();")?;
3652    writeln!(code, "        }}")?;
3653    writeln!(code, "        cmd.commit();")?;
3654    writeln!(code, "        cmd.wait_until_completed();")?;
3655    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
3656    writeln!(code)?;
3657
3658    // Stage: Final norm + logits
3659    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
3660    writeln!(code, "        let t_logits = Instant::now();")?;
3661    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
3662    writeln!(code, "        {{")?;
3663    writeln!(
3664        code,
3665        "            let enc = cmd2.new_compute_command_encoder();"
3666    )?;
3667    writeln!(
3668        code,
3669        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3670    )?;
3671    writeln!(
3672        code,
3673        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3674    )?;
3675    writeln!(code, "            enc.end_encoding();")?;
3676    writeln!(code, "        }}")?;
3677    writeln!(code, "        cmd2.commit();")?;
3678    writeln!(code, "        cmd2.wait_until_completed();")?;
3679    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
3680    writeln!(code)?;
3681
3682    // Print profile results
3683    writeln!(
3684        code,
3685        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
3686    )?;
3687    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
3688    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
3689    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
3690    writeln!(
3691        code,
3692        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
3693    )?;
3694    writeln!(code)?;
3695
3696    // Read back logits
3697    writeln!(code, "        let logits = unsafe {{")?;
3698    writeln!(
3699        code,
3700        "            let ptr = self.logits_buf.contents() as *const f32;"
3701    )?;
3702    writeln!(
3703        code,
3704        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3705    )?;
3706    writeln!(code, "        }};")?;
3707    writeln!(code)?;
3708    writeln!(code, "        self.pos += 1;")?;
3709    writeln!(code, "        logits")?;
3710    writeln!(code, "    }}")?;
3711    writeln!(code)?;
3712
3713    // ── forward_prefill: single-token async forward (backward compat) ──
3714    writeln!(
3715        code,
3716        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
3717    )?;
3718    writeln!(code, "    ///")?;
3719    writeln!(
3720        code,
3721        "    /// Commits the command buffer without waiting, enabling double-buffered"
3722    )?;
3723    writeln!(
3724        code,
3725        "    /// execution: GPU processes token N while CPU encodes token N+1."
3726    )?;
3727    writeln!(
3728        code,
3729        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
3730    )?;
3731    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
3732    writeln!(code, "    }}")?;
3733    writeln!(code)?;
3734
3735    // ── forward_prefill_batch: batched prefill for multiple tokens ──
3736    // Batched matmuls for QKV/O/FFN projections, sequential attention (causal dependency).
3737    let batch_matmul_fn = if is_q8 {
3738        "dispatch_matmul_q8_batch"
3739    } else if is_q4 {
3740        "dispatch_matmul_q4_batch"
3741    } else {
3742        "dispatch_matmul_batch"
3743    };
3744
3745    writeln!(
3746        code,
3747        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
3748    )?;
3749    writeln!(code, "    ///")?;
3750    writeln!(
3751        code,
3752        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
3753    )?;
3754    writeln!(
3755        code,
3756        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
3757    )?;
3758    writeln!(
3759        code,
3760        "    /// This provides significant speedup during prompt prefill."
3761    )?;
3762    writeln!(
3763        code,
3764        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
3765    )?;
3766    writeln!(code, "        let m = tokens.len().min(MAX_BATCH_SIZE);")?;
3767    writeln!(code, "        if m == 0 {{ return; }}")?;
3768    writeln!(code, "        let start_pos = self.pos;")?;
3769    writeln!(code)?;
3770    writeln!(code, "        // Wait for any pending command buffer")?;
3771    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3772    writeln!(code, "            prev.wait_until_completed();")?;
3773    writeln!(code, "        }}")?;
3774    writeln!(code)?;
3775
3776    // Upload token IDs and positions to GPU
3777    writeln!(
3778        code,
3779        "        // Upload token IDs and positions to GPU buffers"
3780    )?;
3781    writeln!(code, "        unsafe {{")?;
3782    writeln!(
3783        code,
3784        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
3785    )?;
3786    writeln!(
3787        code,
3788        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
3789    )?;
3790    writeln!(code, "            for i in 0..m {{")?;
3791    writeln!(code, "                *tok_ptr.add(i) = tokens[i];")?;
3792    writeln!(
3793        code,
3794        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
3795    )?;
3796    writeln!(code, "            }}")?;
3797    writeln!(code, "        }}")?;
3798    writeln!(code)?;
3799
3800    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3801    writeln!(code, "        {{")?;
3802    writeln!(
3803        code,
3804        "            let enc = cmd.new_compute_command_encoder();"
3805    )?;
3806    writeln!(code)?;
3807
3808    // 1. Batch embedding lookup
3809    writeln!(
3810        code,
3811        "            // 1. Batch embedding lookup: copy all token embeddings at once"
3812    )?;
3813    writeln!(
3814        code,
3815        "            self.dispatch_copy_embedding_batch(&enc, m);"
3816    )?;
3817    // Copy batch_hidden -> batch_residual
3818    writeln!(
3819        code,
3820        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
3821    )?;
3822    writeln!(code)?;
3823
3824    // 2. Transformer layers
3825    writeln!(code, "            // 2. Transformer layers")?;
3826    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3827    writeln!(code)?;
3828
3829    // Batch RMS norm: residual -> hidden (batched)
3830    writeln!(
3831        code,
3832        "                // Batch RMS norm: batch_residual -> batch_hidden"
3833    )?;
3834    writeln!(
3835        code,
3836        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
3837    )?;
3838
3839    // Batch QKV matmul
3840    writeln!(
3841        code,
3842        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
3843    )?;
3844    writeln!(
3845        code,
3846        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
3847    )?;
3848    writeln!(code)?;
3849
3850    // Fused RoPE on Q+K portions in a single dispatch
3851    let k_float_offset = hidden;
3852    writeln!(
3853        code,
3854        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
3855    )?;
3856    writeln!(
3857        code,
3858        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
3859    )?;
3860    writeln!(code)?;
3861
3862    // Fused KV cache update: copy both K and V in a single dispatch
3863    let v_float_offset = hidden + kv_dim;
3864    writeln!(
3865        code,
3866        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
3867    )?;
3868    writeln!(
3869        code,
3870        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
3871    )?;
3872    writeln!(code)?;
3873
3874    // Batched causal attention: ONE dispatch for all M tokens
3875    writeln!(
3876        code,
3877        "                // Batched causal attention: one dispatch for all M tokens"
3878    )?;
3879    writeln!(
3880        code,
3881        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
3882    )?;
3883    writeln!(code)?;
3884
3885    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
3886    writeln!(code, "                // Batched O projection")?;
3887    writeln!(
3888        code,
3889        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
3890    )?;
3891    writeln!(code)?;
3892
3893    // Batch add: residual += attn_proj for all tokens
3894    writeln!(
3895        code,
3896        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
3897    )?;
3898    writeln!(code)?;
3899
3900    // Batch FFN
3901    writeln!(
3902        code,
3903        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
3904    )?;
3905    writeln!(
3906        code,
3907        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
3908    )?;
3909    writeln!(
3910        code,
3911        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
3912    )?;
3913    writeln!(
3914        code,
3915        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
3916    )?;
3917    writeln!(
3918        code,
3919        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
3920    )?;
3921    writeln!(
3922        code,
3923        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
3924    )?;
3925    writeln!(code, "            }}")?;
3926    writeln!(code)?;
3927
3928    // Copy last token's residual to single-token residual_buf for next forward() call
3929    writeln!(
3930        code,
3931        "            // Copy last token's residual to single-token buffer for subsequent forward()"
3932    )?;
3933    writeln!(
3934        code,
3935        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
3936    )?;
3937    writeln!(code)?;
3938    writeln!(code, "            enc.end_encoding();")?;
3939    writeln!(code, "        }}")?;
3940    writeln!(code)?;
3941
3942    writeln!(code, "        cmd.commit();")?;
3943    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
3944    writeln!(code, "        self.pos += m;")?;
3945    writeln!(code, "    }}")?;
3946    writeln!(code)?;
3947
3948    // ── reset() — rewind KV cache position for new inference requests ──
3949    writeln!(
3950        code,
3951        "    /// Reset the model state for a new inference request."
3952    )?;
3953    writeln!(code, "    pub fn reset(&mut self) {{")?;
3954    writeln!(code, "        self.pos = 0;")?;
3955    writeln!(code, "        self.prev_cmd = None;")?;
3956    writeln!(code, "    }}")?;
3957    writeln!(code)?;
3958
3959    // ── Private dispatch helpers (all take a shared compute encoder) ──
3960    writeln!(
3961        code,
3962        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
3963    )?;
3964    writeln!(
3965        code,
3966        "    // These methods set pipeline state + buffers + dispatch on an existing"
3967    )?;
3968    writeln!(
3969        code,
3970        "    // encoder, avoiding per-operation encoder creation overhead."
3971    )?;
3972    writeln!(code)?;
3973
3974    // dispatch_rms_norm
3975    writeln!(
3976        code,
3977        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
3978    )?;
3979    writeln!(
3980        code,
3981        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
3982    )?;
3983    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
3984    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
3985    writeln!(
3986        code,
3987        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
3988    )?;
3989    writeln!(
3990        code,
3991        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
3992    )?;
3993    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
3994    writeln!(
3995        code,
3996        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
3997    )?;
3998    writeln!(
3999        code,
4000        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4001    )?;
4002    writeln!(
4003        code,
4004        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4005    )?;
4006    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4007    writeln!(
4008        code,
4009        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
4010    )?;
4011    writeln!(
4012        code,
4013        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4014    )?;
4015    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4016    writeln!(code, "    }}")?;
4017    writeln!(code)?;
4018
4019    // dispatch_matmul
4020    writeln!(
4021        code,
4022        "    /// Dispatch matrix-vector multiply: weight * input -> output."
4023    )?;
4024    writeln!(
4025        code,
4026        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4027    )?;
4028    writeln!(code, "        let r: u32 = rows as u32;")?;
4029    writeln!(code, "        let c: u32 = cols as u32;")?;
4030    writeln!(
4031        code,
4032        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4033    )?;
4034    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4035    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4036    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4037    writeln!(
4038        code,
4039        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4040    )?;
4041    writeln!(
4042        code,
4043        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4044    )?;
4045    writeln!(
4046        code,
4047        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4048    )?;
4049    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4050    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
4051    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4052    writeln!(
4053        code,
4054        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4055    )?;
4056    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4057    writeln!(code, "    }}")?;
4058    writeln!(code)?;
4059
4060    // dispatch_matmul_q8
4061    writeln!(
4062        code,
4063        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4064    )?;
4065    writeln!(
4066        code,
4067        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4068    )?;
4069    writeln!(
4070        code,
4071        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4072    )?;
4073    writeln!(code, "        let r: u32 = rows as u32;")?;
4074    writeln!(code, "        let c: u32 = cols as u32;")?;
4075    writeln!(
4076        code,
4077        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4078    )?;
4079    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4080    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4081    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4082    writeln!(
4083        code,
4084        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4085    )?;
4086    writeln!(
4087        code,
4088        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4089    )?;
4090    writeln!(
4091        code,
4092        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4093    )?;
4094    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4095    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4096    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4097    writeln!(
4098        code,
4099        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4100    )?;
4101    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4102    writeln!(code, "    }}")?;
4103    writeln!(code)?;
4104
4105    // dispatch_matmul_q4
4106    writeln!(
4107        code,
4108        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4109    )?;
4110    writeln!(
4111        code,
4112        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4113    )?;
4114    writeln!(
4115        code,
4116        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4117    )?;
4118    writeln!(code, "        let r: u32 = rows as u32;")?;
4119    writeln!(code, "        let c: u32 = cols as u32;")?;
4120    writeln!(
4121        code,
4122        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4123    )?;
4124    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4125    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4126    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4127    writeln!(
4128        code,
4129        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4130    )?;
4131    writeln!(
4132        code,
4133        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4134    )?;
4135    writeln!(
4136        code,
4137        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4138    )?;
4139    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4140    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
4141    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4142    writeln!(
4143        code,
4144        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4145    )?;
4146    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4147    writeln!(code, "    }}")?;
4148    writeln!(code)?;
4149
4150    // dispatch_rope
4151    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
4152    writeln!(
4153        code,
4154        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4155    )?;
4156    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4157    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4158    writeln!(code, "        let p: u32 = pos as u32;")?;
4159    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4160    writeln!(
4161        code,
4162        "        let total_pairs = num_heads * (head_dim / 2);"
4163    )?;
4164    writeln!(
4165        code,
4166        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4167    )?;
4168    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4169    writeln!(
4170        code,
4171        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4172    )?;
4173    writeln!(
4174        code,
4175        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4176    )?;
4177    writeln!(
4178        code,
4179        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4180    )?;
4181    writeln!(
4182        code,
4183        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4184    )?;
4185    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4186    writeln!(
4187        code,
4188        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4189    )?;
4190    writeln!(
4191        code,
4192        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4193    )?;
4194    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4195    writeln!(code, "    }}")?;
4196    writeln!(code)?;
4197
4198    // dispatch_rope_offset
4199    writeln!(
4200        code,
4201        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4202    )?;
4203    writeln!(
4204        code,
4205        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4206    )?;
4207    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4208    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4209    writeln!(code, "        let p: u32 = pos as u32;")?;
4210    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4211    writeln!(
4212        code,
4213        "        let total_pairs = num_heads * (head_dim / 2);"
4214    )?;
4215    writeln!(
4216        code,
4217        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
4218    )?;
4219    writeln!(
4220        code,
4221        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
4222    )?;
4223    writeln!(
4224        code,
4225        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4226    )?;
4227    writeln!(
4228        code,
4229        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4230    )?;
4231    writeln!(
4232        code,
4233        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4234    )?;
4235    writeln!(
4236        code,
4237        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4238    )?;
4239    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4240    writeln!(
4241        code,
4242        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4243    )?;
4244    writeln!(
4245        code,
4246        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4247    )?;
4248    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4249    writeln!(code, "    }}")?;
4250    writeln!(code)?;
4251
4252    // dispatch_attention
4253    writeln!(
4254        code,
4255        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4256    )?;
4257    writeln!(
4258        code,
4259        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4260    )?;
4261    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4262    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4263    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4264    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4265    writeln!(
4266        code,
4267        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4268    )?;
4269    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4270    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4271    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4272    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4273    writeln!(
4274        code,
4275        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4276    )?;
4277    writeln!(
4278        code,
4279        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4280    )?;
4281    writeln!(
4282        code,
4283        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4284    )?;
4285    writeln!(
4286        code,
4287        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4288    )?;
4289    writeln!(
4290        code,
4291        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4292    )?;
4293    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4294    writeln!(
4295        code,
4296        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4297    )?;
4298    writeln!(
4299        code,
4300        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4301    )?;
4302    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4303    writeln!(code, "    }}")?;
4304    writeln!(code)?;
4305
4306    // dispatch_attention_offset
4307    writeln!(
4308        code,
4309        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4310    )?;
4311    writeln!(
4312        code,
4313        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4314    )?;
4315    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
4316    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4317    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4318    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4319    writeln!(
4320        code,
4321        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
4322    )?;
4323    writeln!(
4324        code,
4325        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
4326    )?;
4327    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4328    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4329    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4330    writeln!(
4331        code,
4332        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4333    )?;
4334    writeln!(
4335        code,
4336        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4337    )?;
4338    writeln!(
4339        code,
4340        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4341    )?;
4342    writeln!(
4343        code,
4344        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4345    )?;
4346    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4347    writeln!(
4348        code,
4349        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4350    )?;
4351    writeln!(
4352        code,
4353        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4354    )?;
4355    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4356    writeln!(code, "    }}")?;
4357    writeln!(code)?;
4358
4359    // dispatch_silu_mul
4360    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
4361    writeln!(
4362        code,
4363        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
4364    )?;
4365    writeln!(code, "        let count: u32 = n as u32;")?;
4366    writeln!(
4367        code,
4368        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
4369    )?;
4370    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
4371    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
4372    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4373    writeln!(
4374        code,
4375        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4376    )?;
4377    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4378    writeln!(
4379        code,
4380        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4381    )?;
4382    writeln!(
4383        code,
4384        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4385    )?;
4386    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4387    writeln!(code, "    }}")?;
4388    writeln!(code)?;
4389
4390    // dispatch_silu_mul_fused
4391    writeln!(
4392        code,
4393        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
4394    )?;
4395    writeln!(
4396        code,
4397        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
4398    )?;
4399    writeln!(
4400        code,
4401        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
4402    )?;
4403    writeln!(code, "        let count: u32 = n as u32;")?;
4404    writeln!(
4405        code,
4406        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
4407    )?;
4408    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
4409    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
4410    writeln!(
4411        code,
4412        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4413    )?;
4414    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4415    writeln!(
4416        code,
4417        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4418    )?;
4419    writeln!(
4420        code,
4421        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4422    )?;
4423    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4424    writeln!(code, "    }}")?;
4425    writeln!(code)?;
4426
4427    // dispatch_copy (simple src -> dst copy via compute kernel)
4428    writeln!(
4429        code,
4430        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
4431    )?;
4432    writeln!(
4433        code,
4434        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4435    )?;
4436    writeln!(code, "        let n: u32 = count as u32;")?;
4437    writeln!(
4438        code,
4439        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4440    )?;
4441    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4442    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4443    writeln!(
4444        code,
4445        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4446    )?;
4447    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4448    writeln!(
4449        code,
4450        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4451    )?;
4452    writeln!(
4453        code,
4454        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4455    )?;
4456    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4457    writeln!(code, "    }}")?;
4458    writeln!(code)?;
4459
4460    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
4461    writeln!(
4462        code,
4463        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
4464    )?;
4465    writeln!(
4466        code,
4467        "    /// Used for embedding table lookup (copy a specific row)."
4468    )?;
4469    writeln!(
4470        code,
4471        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
4472    )?;
4473    writeln!(code, "        let off: u32 = src_offset as u32;")?;
4474    writeln!(code, "        let n: u32 = count as u32;")?;
4475    writeln!(
4476        code,
4477        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
4478    )?;
4479    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4480    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4481    writeln!(
4482        code,
4483        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
4484    )?;
4485    writeln!(
4486        code,
4487        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4488    )?;
4489    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4490    writeln!(
4491        code,
4492        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4493    )?;
4494    writeln!(
4495        code,
4496        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4497    )?;
4498    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4499    writeln!(code, "    }}")?;
4500    writeln!(code)?;
4501
4502    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
4503    writeln!(
4504        code,
4505        "    /// Dispatch copy from source at byte offset to destination at float offset."
4506    )?;
4507    writeln!(
4508        code,
4509        "    /// Used for KV cache updates from fused QKV buffer."
4510    )?;
4511    writeln!(
4512        code,
4513        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4514    )?;
4515    writeln!(code, "        let n: u32 = count as u32;")?;
4516    writeln!(
4517        code,
4518        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4519    )?;
4520    writeln!(
4521        code,
4522        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4523    )?;
4524    writeln!(
4525        code,
4526        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4527    )?;
4528    writeln!(
4529        code,
4530        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4531    )?;
4532    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4533    writeln!(
4534        code,
4535        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4536    )?;
4537    writeln!(
4538        code,
4539        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4540    )?;
4541    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4542    writeln!(code, "    }}")?;
4543    writeln!(code)?;
4544
4545    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
4546    writeln!(
4547        code,
4548        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
4549    )?;
4550    writeln!(
4551        code,
4552        "    /// Used for KV cache updates (write to a specific position in the cache)."
4553    )?;
4554    writeln!(
4555        code,
4556        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
4557    )?;
4558    writeln!(code, "        let n: u32 = count as u32;")?;
4559    writeln!(
4560        code,
4561        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4562    )?;
4563    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4564    writeln!(
4565        code,
4566        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
4567    )?;
4568    writeln!(
4569        code,
4570        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4571    )?;
4572    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4573    writeln!(
4574        code,
4575        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4576    )?;
4577    writeln!(
4578        code,
4579        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4580    )?;
4581    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4582    writeln!(code, "    }}")?;
4583    writeln!(code)?;
4584
4585    // dispatch_add_inplace (residual connection, no blit needed)
4586    writeln!(
4587        code,
4588        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
4589    )?;
4590    writeln!(
4591        code,
4592        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
4593    )?;
4594    writeln!(code, "        let count: u32 = n as u32;")?;
4595    writeln!(
4596        code,
4597        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
4598    )?;
4599    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
4600    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
4601    writeln!(
4602        code,
4603        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4604    )?;
4605    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4606    writeln!(
4607        code,
4608        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4609    )?;
4610    writeln!(
4611        code,
4612        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4613    )?;
4614    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4615    writeln!(code, "    }}")?;
4616    writeln!(code)?;
4617
4618    // ── Batched prefill dispatch helpers ──
4619    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
4620    writeln!(code)?;
4621
4622    // dispatch_copy_embedding_batch
4623    writeln!(
4624        code,
4625        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
4626    )?;
4627    writeln!(
4628        code,
4629        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
4630    )?;
4631    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
4632    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4633    writeln!(
4634        code,
4635        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
4636    )?;
4637    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
4638    writeln!(
4639        code,
4640        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
4641    )?;
4642    writeln!(
4643        code,
4644        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
4645    )?;
4646    writeln!(
4647        code,
4648        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
4649    )?;
4650    writeln!(
4651        code,
4652        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4653    )?;
4654    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
4655    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4656    writeln!(
4657        code,
4658        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4659    )?;
4660    writeln!(
4661        code,
4662        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4663    )?;
4664    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4665    writeln!(code, "    }}")?;
4666    writeln!(code)?;
4667
4668    // dispatch_rms_norm_batch
4669    writeln!(
4670        code,
4671        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
4672    )?;
4673    writeln!(
4674        code,
4675        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
4676    )?;
4677    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4678    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4679    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4680    writeln!(
4681        code,
4682        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
4683    )?;
4684    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
4685    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4686    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4687    writeln!(
4688        code,
4689        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4690    )?;
4691    writeln!(
4692        code,
4693        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4694    )?;
4695    writeln!(
4696        code,
4697        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4698    )?;
4699    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4700    writeln!(
4701        code,
4702        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
4703    )?;
4704    writeln!(
4705        code,
4706        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4707    )?;
4708    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4709    writeln!(code, "    }}")?;
4710    writeln!(code)?;
4711
4712    // dispatch_matmul_batch (f32)
4713    writeln!(
4714        code,
4715        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4716    )?;
4717    writeln!(
4718        code,
4719        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4720    )?;
4721    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4722    writeln!(code, "        let r: u32 = rows as u32;")?;
4723    writeln!(code, "        let c: u32 = cols as u32;")?;
4724    writeln!(
4725        code,
4726        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
4727    )?;
4728    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4729    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4730    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4731    writeln!(
4732        code,
4733        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4734    )?;
4735    writeln!(
4736        code,
4737        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4738    )?;
4739    writeln!(
4740        code,
4741        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4742    )?;
4743    writeln!(
4744        code,
4745        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
4746    )?;
4747    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
4748    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4749    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4750    writeln!(
4751        code,
4752        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4753    )?;
4754    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4755    writeln!(code, "    }}")?;
4756    writeln!(code)?;
4757
4758    // dispatch_matmul_q8_batch
4759    writeln!(
4760        code,
4761        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4762    )?;
4763    writeln!(code, "    ///")?;
4764    writeln!(
4765        code,
4766        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
4767    )?;
4768    writeln!(
4769        code,
4770        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
4771    )?;
4772    writeln!(
4773        code,
4774        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4775    )?;
4776    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4777    writeln!(code, "        let r: u32 = rows as u32;")?;
4778    writeln!(code, "        let c: u32 = cols as u32;")?;
4779    writeln!(
4780        code,
4781        "        // Tile sizes must match the Metal shader constants."
4782    )?;
4783    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
4784    writeln!(code, "        const MMA_TOK_TILE: usize = 16;")?;
4785    writeln!(code, "        const MMA_ROW_TILE: usize = 16;")?;
4786    writeln!(code, "        const MMA32_TOK_TILE: usize = 32;")?;
4787    writeln!(code, "        const MMA32_ROW_TILE: usize = 32;")?;
4788    writeln!(
4789        code,
4790        "        // Hardware matrix-multiply paths (simdgroup_matrix)."
4791    )?;
4792    writeln!(
4793        code,
4794        "        // Prefer the large 32×32 tile when the problem supports it — halves"
4795    )?;
4796    writeln!(
4797        code,
4798        "        // dispatch count and reuses each weight load across 32 tokens."
4799    )?;
4800    writeln!(
4801        code,
4802        "        if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
4803    )?;
4804    writeln!(
4805        code,
4806        "            // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
4807    )?;
4808    writeln!(
4809        code,
4810        "            // It wins at moderate prefill lengths where the GPU is wave-starved,"
4811    )?;
4812    writeln!(
4813        code,
4814        "            // but the f32→f16 conversion overhead slightly hurts the small-hidden"
4815    )?;
4816    writeln!(
4817        code,
4818        "            // case (135M / 360M).  Switch at cols >= 2048 — a clean split that"
4819    )?;
4820    writeln!(
4821        code,
4822        "            // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
4823    )?;
4824    writeln!(code, "            let use_h4 = cols >= 2048;")?;
4825    writeln!(code, "            let pipe = if use_h4 {{")?;
4826    writeln!(code, "                &self.matmul_q8_mma32_h4_pipeline")?;
4827    writeln!(code, "            }} else {{")?;
4828    writeln!(code, "                &self.matmul_q8_mma32_pipeline")?;
4829    writeln!(code, "            }};")?;
4830    writeln!(
4831        code,
4832        "            enc.set_compute_pipeline_state(pipe);"
4833    )?;
4834    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
4835    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
4836    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
4837    writeln!(
4838        code,
4839        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4840    )?;
4841    writeln!(
4842        code,
4843        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4844    )?;
4845    writeln!(
4846        code,
4847        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4848    )?;
4849    writeln!(
4850        code,
4851        "            let row_tgs = rows / MMA32_ROW_TILE;"
4852    )?;
4853    writeln!(
4854        code,
4855        "            let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
4856    )?;
4857    writeln!(
4858        code,
4859        "            let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
4860    )?;
4861    writeln!(
4862        code,
4863        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
4864    )?;
4865    writeln!(
4866        code,
4867        "            enc.dispatch_thread_groups(grid_size, tg_size);"
4868    )?;
4869    writeln!(
4870        code,
4871        "        }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
4872    )?;
4873    writeln!(
4874        code,
4875        "            enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
4876    )?;
4877    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
4878    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
4879    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
4880    writeln!(
4881        code,
4882        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4883    )?;
4884    writeln!(
4885        code,
4886        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4887    )?;
4888    writeln!(
4889        code,
4890        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4891    )?;
4892    writeln!(
4893        code,
4894        "            let row_tgs = rows / MMA_ROW_TILE;"
4895    )?;
4896    writeln!(
4897        code,
4898        "            let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
4899    )?;
4900    writeln!(
4901        code,
4902        "            let tg_size = MTLSize::new(128, 1, 1);  // 4 simdgroups × 32 lanes"
4903    )?;
4904    writeln!(
4905        code,
4906        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
4907    )?;
4908    writeln!(
4909        code,
4910        "            enc.dispatch_thread_groups(grid_size, tg_size);"
4911    )?;
4912    writeln!(code, "        }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
4913    writeln!(
4914        code,
4915        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
4916    )?;
4917    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
4918    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
4919    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
4920    writeln!(
4921        code,
4922        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4923    )?;
4924    writeln!(
4925        code,
4926        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4927    )?;
4928    writeln!(
4929        code,
4930        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4931    )?;
4932    writeln!(
4933        code,
4934        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
4935    )?;
4936    writeln!(
4937        code,
4938        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
4939    )?;
4940    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
4941    writeln!(
4942        code,
4943        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
4944    )?;
4945    writeln!(
4946        code,
4947        "            enc.dispatch_thread_groups(grid_size, tg_size);"
4948    )?;
4949    writeln!(code, "        }} else {{")?;
4950    writeln!(
4951        code,
4952        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
4953    )?;
4954    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
4955    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
4956    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
4957    writeln!(
4958        code,
4959        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4960    )?;
4961    writeln!(
4962        code,
4963        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4964    )?;
4965    writeln!(
4966        code,
4967        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4968    )?;
4969    writeln!(
4970        code,
4971        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
4972    )?;
4973    writeln!(
4974        code,
4975        "            let num_tg = (row_tgs * num_tokens) as u64;"
4976    )?;
4977    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
4978    writeln!(
4979        code,
4980        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
4981    )?;
4982    writeln!(
4983        code,
4984        "            enc.dispatch_thread_groups(grid_size, tg_size);"
4985    )?;
4986    writeln!(code, "        }}")?;
4987    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4988    writeln!(code, "    }}")?;
4989    writeln!(code)?;
4990
4991    // dispatch_matmul_q4_batch
4992    writeln!(
4993        code,
4994        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4995    )?;
4996    writeln!(
4997        code,
4998        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4999    )?;
5000    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5001    writeln!(code, "        let r: u32 = rows as u32;")?;
5002    writeln!(code, "        let c: u32 = cols as u32;")?;
5003    writeln!(
5004        code,
5005        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5006    )?;
5007    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
5008    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
5009    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
5010    writeln!(
5011        code,
5012        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5013    )?;
5014    writeln!(
5015        code,
5016        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5017    )?;
5018    writeln!(
5019        code,
5020        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5021    )?;
5022    writeln!(
5023        code,
5024        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
5025    )?;
5026    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
5027    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5028    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5029    writeln!(
5030        code,
5031        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5032    )?;
5033    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5034    writeln!(code, "    }}")?;
5035    writeln!(code)?;
5036
5037    // dispatch_rope_batch
5038    writeln!(
5039        code,
5040        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5041    )?;
5042    writeln!(
5043        code,
5044        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5045    )?;
5046    writeln!(
5047        code,
5048        "    /// `row_stride` is the number of floats per token row in the batch buffer."
5049    )?;
5050    writeln!(
5051        code,
5052        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5053    )?;
5054    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
5055    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
5056    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5057    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5058    writeln!(
5059        code,
5060        "        let pairs_per_token = num_heads * (head_dim / 2);"
5061    )?;
5062    writeln!(
5063        code,
5064        "        let total_pairs = num_tokens * pairs_per_token;"
5065    )?;
5066    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data.
5067    // Since our batch_qkv_buf is [M, qkv_rows] and Q/K are at offsets within each row,
5068    // we need to pass the buffer at the right byte offset for each token's data.
5069    // Actually, the rope_batch kernel accesses data[token * (num_heads * head_dim) + ...],
5070    // but our layout is data[token * row_stride + data_float_offset + ...].
5071    // We need the kernel to know the row_stride. Let me adjust the kernel approach:
5072    // Since Q and K are contiguous within each token's qkv_rows, and the batch buffer
5073    // is [M, qkv_rows], we can pass the buffer at offset (data_float_offset * 4) and
5074    // use a stride parameter. But the rope_batch kernel as written expects [M, num_heads*head_dim].
5075    //
5076    // Simplest approach: use the single-token rope kernel for each token in a loop.
5077    // This is still efficient because we're dispatching all within the same command encoder.
5078    writeln!(
5079        code,
5080        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5081    )?;
5082    writeln!(code, "        for t in 0..num_tokens {{")?;
5083    writeln!(
5084        code,
5085        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5086    )?;
5087    writeln!(
5088        code,
5089        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5090    )?;
5091    writeln!(
5092        code,
5093        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
5094    )?;
5095    writeln!(
5096        code,
5097        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
5098    )?;
5099    writeln!(
5100        code,
5101        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5102    )?;
5103    writeln!(
5104        code,
5105        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5106    )?;
5107    writeln!(
5108        code,
5109        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5110    )?;
5111    writeln!(
5112        code,
5113        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5114    )?;
5115    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
5116    writeln!(
5117        code,
5118        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5119    )?;
5120    writeln!(
5121        code,
5122        "            enc.dispatch_thread_groups(grid_size, tg_size);"
5123    )?;
5124    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5125    writeln!(code, "        }}")?;
5126    writeln!(code, "    }}")?;
5127    writeln!(code)?;
5128
5129    // dispatch_silu_mul_fused_batch
5130    writeln!(
5131        code,
5132        "    /// Dispatch batched fused SiLU-multiply for M tokens."
5133    )?;
5134    writeln!(
5135        code,
5136        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5137    )?;
5138    writeln!(code, "        let count: u32 = n as u32;")?;
5139    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
5140    writeln!(
5141        code,
5142        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5143    )?;
5144    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
5145    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
5146    writeln!(
5147        code,
5148        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5149    )?;
5150    writeln!(
5151        code,
5152        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5153    )?;
5154    writeln!(code, "        let total = num_tokens * n;")?;
5155    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5156    writeln!(
5157        code,
5158        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5159    )?;
5160    writeln!(
5161        code,
5162        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5163    )?;
5164    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5165    writeln!(code, "    }}")?;
5166    writeln!(code)?;
5167
5168    // dispatch_add_inplace_batch_n (add n elements in-place)
5169    writeln!(
5170        code,
5171        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5172    )?;
5173    writeln!(
5174        code,
5175        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5176    )?;
5177    writeln!(code, "        let count: u32 = total_n as u32;")?;
5178    writeln!(
5179        code,
5180        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5181    )?;
5182    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
5183    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
5184    writeln!(
5185        code,
5186        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5187    )?;
5188    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5189    writeln!(
5190        code,
5191        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5192    )?;
5193    writeln!(
5194        code,
5195        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5196    )?;
5197    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5198    writeln!(code, "    }}")?;
5199    writeln!(code)?;
5200
5201    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
5202    writeln!(
5203        code,
5204        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
5205    )?;
5206    writeln!(
5207        code,
5208        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5209    )?;
5210    writeln!(code, "        let n: u32 = count as u32;")?;
5211    writeln!(
5212        code,
5213        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5214    )?;
5215    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5216    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5217    writeln!(
5218        code,
5219        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5220    )?;
5221    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5222    writeln!(
5223        code,
5224        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5225    )?;
5226    writeln!(
5227        code,
5228        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5229    )?;
5230    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5231    writeln!(code, "    }}")?;
5232    writeln!(code)?;
5233
5234    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
5235    writeln!(
5236        code,
5237        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5238    )?;
5239    writeln!(
5240        code,
5241        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5242    )?;
5243    writeln!(code, "        let n: u32 = count as u32;")?;
5244    writeln!(
5245        code,
5246        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5247    )?;
5248    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5249    writeln!(
5250        code,
5251        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5252    )?;
5253    writeln!(
5254        code,
5255        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5256    )?;
5257    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5258    writeln!(
5259        code,
5260        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5261    )?;
5262    writeln!(
5263        code,
5264        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5265    )?;
5266    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5267    writeln!(code, "    }}")?;
5268    writeln!(code)?;
5269
5270    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
5271    writeln!(
5272        code,
5273        "    /// Copy from src at byte offset to dst at float offset."
5274    )?;
5275    writeln!(
5276        code,
5277        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5278    )?;
5279    writeln!(code, "        let n: u32 = count as u32;")?;
5280    writeln!(
5281        code,
5282        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
5283    )?;
5284    writeln!(
5285        code,
5286        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5287    )?;
5288    writeln!(
5289        code,
5290        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5291    )?;
5292    writeln!(
5293        code,
5294        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5295    )?;
5296    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5297    writeln!(
5298        code,
5299        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5300    )?;
5301    writeln!(
5302        code,
5303        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5304    )?;
5305    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5306    writeln!(code, "    }}")?;
5307    writeln!(code)?;
5308
5309    // dispatch_copy_kv_batch
5310    writeln!(
5311        code,
5312        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
5313    )?;
5314    writeln!(
5315        code,
5316        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
5317    )?;
5318    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5319    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5320    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5321    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5322    writeln!(code, "        let so: u32 = src_offset as u32;")?;
5323    writeln!(
5324        code,
5325        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
5326    )?;
5327    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5328    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
5329    writeln!(
5330        code,
5331        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5332    )?;
5333    writeln!(
5334        code,
5335        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5336    )?;
5337    writeln!(
5338        code,
5339        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5340    )?;
5341    writeln!(
5342        code,
5343        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5344    )?;
5345    writeln!(
5346        code,
5347        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
5348    )?;
5349    writeln!(code, "        let total = num_tokens * kv_dim;")?;
5350    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5351    writeln!(
5352        code,
5353        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5354    )?;
5355    writeln!(
5356        code,
5357        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5358    )?;
5359    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5360    writeln!(code, "    }}")?;
5361    writeln!(code)?;
5362
5363    // dispatch_attention_batch
5364    writeln!(
5365        code,
5366        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
5367    )?;
5368    writeln!(
5369        code,
5370        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
5371    )?;
5372    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5373    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5374    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
5375    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5376    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5377    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
5378    writeln!(
5379        code,
5380        "        enc.set_compute_pipeline_state(&self.attention_batch_pipeline);"
5381    )?;
5382    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
5383    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
5384    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
5385    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
5386    writeln!(
5387        code,
5388        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5389    )?;
5390    writeln!(
5391        code,
5392        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5393    )?;
5394    writeln!(
5395        code,
5396        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5397    )?;
5398    writeln!(
5399        code,
5400        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5401    )?;
5402    writeln!(
5403        code,
5404        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5405    )?;
5406    writeln!(
5407        code,
5408        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5409    )?;
5410    writeln!(
5411        code,
5412        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
5413    )?;
5414    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5415    writeln!(
5416        code,
5417        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
5418    )?;
5419    writeln!(
5420        code,
5421        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5422    )?;
5423    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5424    writeln!(code, "    }}")?;
5425    writeln!(code)?;
5426
5427    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
5428    writeln!(
5429        code,
5430        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
5431    )?;
5432    writeln!(
5433        code,
5434        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
5435    )?;
5436    writeln!(
5437        code,
5438        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
5439    )?;
5440    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5441    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5442    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
5443    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
5444    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
5445    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
5446    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
5447    writeln!(
5448        code,
5449        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
5450    )?;
5451    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
5452    writeln!(
5453        code,
5454        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5455    )?;
5456    writeln!(
5457        code,
5458        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5459    )?;
5460    writeln!(
5461        code,
5462        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
5463    )?;
5464    writeln!(
5465        code,
5466        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5467    )?;
5468    writeln!(
5469        code,
5470        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5471    )?;
5472    writeln!(
5473        code,
5474        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5475    )?;
5476    writeln!(
5477        code,
5478        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5479    )?;
5480    writeln!(
5481        code,
5482        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
5483    )?;
5484    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5485    writeln!(
5486        code,
5487        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5488    )?;
5489    writeln!(
5490        code,
5491        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5492    )?;
5493    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5494    writeln!(code, "    }}")?;
5495    writeln!(code)?;
5496
5497    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
5498    writeln!(
5499        code,
5500        "    /// Dispatch fused K+V cache copy in one kernel launch."
5501    )?;
5502    writeln!(
5503        code,
5504        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
5505    )?;
5506    writeln!(
5507        code,
5508        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
5509    )?;
5510    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
5511    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
5512    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
5513    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
5514    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
5515    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
5516    writeln!(
5517        code,
5518        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
5519    )?;
5520    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
5521    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
5522    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
5523    writeln!(
5524        code,
5525        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5526    )?;
5527    writeln!(
5528        code,
5529        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5530    )?;
5531    writeln!(
5532        code,
5533        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5534    )?;
5535    writeln!(
5536        code,
5537        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5538    )?;
5539    writeln!(
5540        code,
5541        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
5542    )?;
5543    writeln!(
5544        code,
5545        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
5546    )?;
5547    writeln!(
5548        code,
5549        "        let total = num_tokens * kv_dim * 2;  // K + V"
5550    )?;
5551    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
5552    writeln!(
5553        code,
5554        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5555    )?;
5556    writeln!(
5557        code,
5558        "        enc.dispatch_thread_groups(grid_size, tg_size);"
5559    )?;
5560    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5561    writeln!(code, "    }}")?;
5562
5563    writeln!(code, "}}")?;
5564    writeln!(code)?;
5565
5566    Ok(())
5567}
5568
5569fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
5570    writeln!(
5571        code,
5572        "// ── Helper functions ──────────────────────────────────"
5573    )?;
5574    writeln!(code)?;
5575    writeln!(
5576        code,
5577        "/// Create a compute pipeline from a named function in the Metal library."
5578    )?;
5579    writeln!(
5580        code,
5581        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
5582    )?;
5583    writeln!(
5584        code,
5585        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
5586    )?;
5587    writeln!(
5588        code,
5589        "    device.new_compute_pipeline_state_with_function(&func)"
5590    )?;
5591    writeln!(
5592        code,
5593        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
5594    )?;
5595    writeln!(code, "}}")?;
5596    writeln!(code)?;
5597
5598    Ok(())
5599}
5600
5601// ---------------------------------------------------------------------------
5602// main.rs generation
5603// ---------------------------------------------------------------------------
5604
5605fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
5606    let _sanitized = sanitize_name(model_name);
5607    let _vocab = config.vocab_size;
5608
5609    let mut code = String::with_capacity(16 * 1024);
5610    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
5611    writeln!(
5612        code,
5613        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
5614    )?;
5615    writeln!(code)?;
5616    writeln!(code, "mod model;")?;
5617    writeln!(code)?;
5618    writeln!(code, "use std::io::Write;")?;
5619    writeln!(code, "use std::time::Instant;")?;
5620    writeln!(code, "use serde::Deserialize;")?;
5621    writeln!(code)?;
5622
5623    // -- main function --
5624    writeln!(code, "fn main() {{")?;
5625    writeln!(
5626        code,
5627        "    let args: Vec<String> = std::env::args().collect();"
5628    )?;
5629    writeln!(code)?;
5630    writeln!(
5631        code,
5632        "    // Detect --serve mode (only requires weights + tokenizer)"
5633    )?;
5634    writeln!(
5635        code,
5636        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
5637    )?;
5638    writeln!(code)?;
5639    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
5640    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
5641    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
5642    writeln!(code, "        std::process::exit(1);")?;
5643    writeln!(code, "    }}")?;
5644    writeln!(code)?;
5645    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
5646    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
5647    writeln!(code, "        std::process::exit(1);")?;
5648    writeln!(code, "    }}")?;
5649    writeln!(code)?;
5650    writeln!(code, "    let weights_path = &args[1];")?;
5651    writeln!(code, "    let tokenizer_path = &args[2];")?;
5652    writeln!(code)?;
5653    writeln!(code, "    // Parse optional flags")?;
5654    writeln!(code, "    let mut max_tokens: usize = 128;")?;
5655    writeln!(code, "    let mut port: u16 = 8080;")?;
5656    writeln!(
5657        code,
5658        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
5659    )?;
5660    writeln!(
5661        code,
5662        "    let profile = args.iter().any(|a| a == \"--profile\");"
5663    )?;
5664    writeln!(code, "    let mut i = 3;")?;
5665    writeln!(code, "    while i < args.len() {{")?;
5666    writeln!(
5667        code,
5668        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
5669    )?;
5670    writeln!(
5671        code,
5672        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
5673    )?;
5674    writeln!(code, "            i += 2;")?;
5675    writeln!(
5676        code,
5677        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
5678    )?;
5679    writeln!(
5680        code,
5681        "            port = args[i + 1].parse().unwrap_or(8080);"
5682    )?;
5683    writeln!(code, "            i += 2;")?;
5684    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
5685    writeln!(code, "            i += 1;")?;
5686    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
5687    writeln!(code, "            i += 1;")?;
5688    writeln!(code, "        }} else {{")?;
5689    writeln!(code, "            i += 1;")?;
5690    writeln!(code, "        }}")?;
5691    writeln!(code, "    }}")?;
5692    writeln!(code)?;
5693
5694    // -- load model (shared by both modes) --
5695    writeln!(
5696        code,
5697        "    // Memory-map weights for zero-copy loading on Apple Silicon"
5698    )?;
5699    writeln!(
5700        code,
5701        "    let weights_file = std::fs::File::open(weights_path)"
5702    )?;
5703    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
5704    writeln!(
5705        code,
5706        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
5707    )?;
5708    writeln!(code)?;
5709    writeln!(code, "    // Load tokenizer")?;
5710    writeln!(
5711        code,
5712        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
5713    )?;
5714    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
5715    writeln!(code)?;
5716    writeln!(code, "    // Create Metal model")?;
5717    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
5718    writeln!(
5719        code,
5720        "    let mut model = model::MetalModel::new(&weights_mmap);"
5721    )?;
5722    writeln!(code)?;
5723
5724    // -- branch: serve vs CLI --
5725    writeln!(code, "    if serve_mode {{")?;
5726    writeln!(code, "        serve(model, tokenizer, port);")?;
5727    writeln!(code, "    }} else {{")?;
5728    writeln!(code, "        let prompt = &args[3];")?;
5729    writeln!(
5730        code,
5731        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
5732    )?;
5733    writeln!(code, "    }}")?;
5734    writeln!(code, "}}")?;
5735    writeln!(code)?;
5736
5737    // -- cli_mode function --
5738    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
5739    writeln!(code, "    // Tokenize prompt")?;
5740    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
5741    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
5742    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
5743    writeln!(code)?;
5744    writeln!(
5745        code,
5746        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
5747    )?;
5748    writeln!(
5749        code,
5750        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
5751    )?;
5752    writeln!(
5753        code,
5754        "    // The last token uses synchronous forward() to get logits."
5755    )?;
5756    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
5757    writeln!(code, "    let prefill_start = Instant::now();")?;
5758    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
5759    writeln!(
5760        code,
5761        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
5762    )?;
5763    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
5764    writeln!(code, "    }} else {{")?;
5765    writeln!(code, "        model.forward(prompt_tokens[0])")?;
5766    writeln!(code, "    }};")?;
5767    writeln!(
5768        code,
5769        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
5770    )?;
5771    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
5772    writeln!(
5773        code,
5774        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
5775    )?;
5776    writeln!(
5777        code,
5778        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
5779    )?;
5780    writeln!(code)?;
5781    writeln!(code, "    // Generate tokens")?;
5782    writeln!(code, "    let mut next_token = argmax(&logits);")?;
5783    writeln!(code, "    let gen_start = Instant::now();")?;
5784    writeln!(code, "    let mut generated_count: usize = 0;")?;
5785    writeln!(code)?;
5786    writeln!(code, "    for _ in 0..max_tokens {{")?;
5787    writeln!(
5788        code,
5789        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
5790    )?;
5791    writeln!(code, "            if !quiet {{")?;
5792    writeln!(code, "                print!(\"{{}}\", text);")?;
5793    writeln!(code, "                std::io::stdout().flush().ok();")?;
5794    writeln!(code, "            }}")?;
5795    writeln!(code, "        }}")?;
5796    writeln!(code, "        generated_count += 1;")?;
5797    writeln!(code)?;
5798    writeln!(
5799        code,
5800        "        // Use profiling forward for first token when --profile is set"
5801    )?;
5802    writeln!(
5803        code,
5804        "        let logits = if profile && generated_count == 1 {{"
5805    )?;
5806    writeln!(code, "            model.forward_profile(next_token)")?;
5807    writeln!(code, "        }} else {{")?;
5808    writeln!(code, "            model.forward(next_token)")?;
5809    writeln!(code, "        }};")?;
5810    writeln!(code, "        next_token = argmax(&logits);")?;
5811    writeln!(code)?;
5812    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
5813    writeln!(code, "        if next_token == 2 {{")?;
5814    writeln!(code, "            break;")?;
5815    writeln!(code, "        }}")?;
5816    writeln!(code)?;
5817    writeln!(
5818        code,
5819        "        // Yield between tokens to reduce sustained GPU thermal load."
5820    )?;
5821    writeln!(
5822        code,
5823        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
5824    )?;
5825    writeln!(
5826        code,
5827        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
5828    )?;
5829    writeln!(
5830        code,
5831        "        // briefly, providing a micro-break that helps sustain peak throughput."
5832    )?;
5833    writeln!(code, "        std::thread::yield_now();")?;
5834    writeln!(code, "    }}")?;
5835    writeln!(code, "    if !quiet {{")?;
5836    writeln!(code, "        println!();")?;
5837    writeln!(code, "    }}")?;
5838    writeln!(
5839        code,
5840        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
5841    )?;
5842    writeln!(
5843        code,
5844        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
5845    )?;
5846    writeln!(
5847        code,
5848        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
5849    )?;
5850    writeln!(code, "}}")?;
5851    writeln!(code)?;
5852
5853    // -- argmax helper --
5854    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
5855    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
5856    writeln!(code, "    logits.iter()")?;
5857    writeln!(code, "        .enumerate()")?;
5858    writeln!(
5859        code,
5860        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
5861    )?;
5862    writeln!(code, "        .map(|(i, _)| i as u32)")?;
5863    writeln!(code, "        .unwrap_or(0)")?;
5864    writeln!(code, "}}")?;
5865    writeln!(code)?;
5866
5867    // -- Request/Response types for OpenAI API --
5868    writeln!(
5869        code,
5870        "// -----------------------------------------------------------------------"
5871    )?;
5872    writeln!(code, "// OpenAI-compatible API server")?;
5873    writeln!(
5874        code,
5875        "// -----------------------------------------------------------------------"
5876    )?;
5877    writeln!(code)?;
5878    writeln!(code, "#[derive(Deserialize)]")?;
5879    writeln!(code, "struct ChatRequest {{")?;
5880    writeln!(code, "    messages: Vec<ChatMessage>,")?;
5881    writeln!(code, "    #[serde(default)]")?;
5882    writeln!(code, "    stream: Option<bool>,")?;
5883    writeln!(code, "    #[serde(default)]")?;
5884    writeln!(code, "    max_tokens: Option<usize>,")?;
5885    writeln!(code, "    #[serde(default)]")?;
5886    writeln!(code, "    temperature: Option<f32>,")?;
5887    writeln!(code, "    #[serde(default)]")?;
5888    writeln!(code, "    model: Option<String>,")?;
5889    writeln!(code, "}}")?;
5890    writeln!(code)?;
5891    writeln!(code, "#[derive(Deserialize)]")?;
5892    writeln!(code, "struct ChatMessage {{")?;
5893    writeln!(code, "    role: String,")?;
5894    writeln!(code, "    content: String,")?;
5895    writeln!(code, "}}")?;
5896    writeln!(code)?;
5897
5898    // -- format_chat_messages --
5899    writeln!(
5900        code,
5901        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
5902    )?;
5903    writeln!(code, "    let mut prompt = String::new();")?;
5904    writeln!(code, "    for msg in messages {{")?;
5905    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
5906    writeln!(code, "    }}")?;
5907    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
5908    writeln!(code, "    prompt")?;
5909    writeln!(code, "}}")?;
5910    writeln!(code)?;
5911
5912    // -- prefill helper --
5913    writeln!(
5914        code,
5915        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
5916    )?;
5917    writeln!(code, "    let len = tokens.len();")?;
5918    writeln!(code, "    if len > 1 {{")?;
5919    writeln!(
5920        code,
5921        "        model.forward_prefill_batch(&tokens[..len - 1]);"
5922    )?;
5923    writeln!(code, "    }}")?;
5924    writeln!(code, "    model.forward(tokens[len - 1])")?;
5925    writeln!(code, "}}")?;
5926    writeln!(code)?;
5927
5928    // -- serve function --
5929    writeln!(
5930        code,
5931        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
5932    )?;
5933    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
5934    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
5935    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
5936    writeln!(
5937        code,
5938        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
5939    )?;
5940    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
5941    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
5942    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
5943    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
5944    writeln!(code)?;
5945    writeln!(code, "    for request in server.incoming_requests() {{")?;
5946    writeln!(code, "        let method = request.method().to_string();")?;
5947    writeln!(code, "        let url = request.url().to_string();")?;
5948    writeln!(code)?;
5949    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
5950
5951    // -- POST /v1/chat/completions --
5952    writeln!(
5953        code,
5954        "            (\"POST\", \"/v1/chat/completions\") => {{"
5955    )?;
5956    writeln!(
5957        code,
5958        "                handle_chat_completion(&mut model, &tokenizer, request);"
5959    )?;
5960    writeln!(code, "            }}")?;
5961
5962    // -- GET /v1/models --
5963    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
5964    writeln!(code, "                let body = serde_json::json!({{")?;
5965    writeln!(code, "                    \"object\": \"list\",")?;
5966    writeln!(code, "                    \"data\": [{{")?;
5967    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
5968    writeln!(code, "                        \"object\": \"model\",")?;
5969    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
5970    writeln!(code, "                    }}]")?;
5971    writeln!(code, "                }});")?;
5972    writeln!(
5973        code,
5974        "                let resp = tiny_http::Response::from_string(body.to_string())"
5975    )?;
5976    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
5977    writeln!(code, "                request.respond(resp).ok();")?;
5978    writeln!(code, "            }}")?;
5979
5980    // -- GET /health --
5981    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
5982    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
5983    writeln!(code, "                request.respond(resp).ok();")?;
5984    writeln!(code, "            }}")?;
5985
5986    // -- 404 --
5987    writeln!(code, "            _ => {{")?;
5988    writeln!(
5989        code,
5990        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
5991    )?;
5992    writeln!(code, "                    .with_status_code(404);")?;
5993    writeln!(code, "                request.respond(resp).ok();")?;
5994    writeln!(code, "            }}")?;
5995    writeln!(code, "        }}")?;
5996    writeln!(code, "    }}")?;
5997    writeln!(code, "}}")?;
5998    writeln!(code)?;
5999
6000    // -- handle_chat_completion --
6001    writeln!(code, "fn handle_chat_completion(")?;
6002    writeln!(code, "    model: &mut model::MetalModel,")?;
6003    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
6004    writeln!(code, "    mut request: tiny_http::Request,")?;
6005    writeln!(code, ") {{")?;
6006    writeln!(code, "    // Read request body")?;
6007    writeln!(code, "    let mut body = String::new();")?;
6008    writeln!(
6009        code,
6010        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
6011    )?;
6012    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6013    writeln!(code, "            .with_status_code(400);")?;
6014    writeln!(code, "        request.respond(resp).ok();")?;
6015    writeln!(code, "        return;")?;
6016    writeln!(code, "    }}")?;
6017    writeln!(code)?;
6018    writeln!(code, "    // Parse JSON")?;
6019    writeln!(
6020        code,
6021        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
6022    )?;
6023    writeln!(code, "        Ok(r) => r,")?;
6024    writeln!(code, "        Err(e) => {{")?;
6025    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6026    writeln!(code, "                .with_status_code(400);")?;
6027    writeln!(code, "            request.respond(resp).ok();")?;
6028    writeln!(code, "            return;")?;
6029    writeln!(code, "        }}")?;
6030    writeln!(code, "    }};")?;
6031    writeln!(code)?;
6032    writeln!(
6033        code,
6034        "    let prompt = format_chat_messages(&req.messages);"
6035    )?;
6036    writeln!(
6037        code,
6038        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6039    )?;
6040    writeln!(code, "        Ok(e) => e,")?;
6041    writeln!(code, "        Err(e) => {{")?;
6042    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6043    writeln!(code, "                .with_status_code(500);")?;
6044    writeln!(code, "            request.respond(resp).ok();")?;
6045    writeln!(code, "            return;")?;
6046    writeln!(code, "        }}")?;
6047    writeln!(code, "    }};")?;
6048    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
6049    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
6050    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
6051    writeln!(
6052        code,
6053        "    let _temperature = req.temperature.unwrap_or(1.0);"
6054    )?;
6055    writeln!(code)?;
6056
6057    // -- Reset KV cache for each request --
6058    writeln!(code, "    model.reset();")?;
6059    writeln!(code)?;
6060
6061    // -- Prefill with timing --
6062    writeln!(code, "    let prefill_start = Instant::now();")?;
6063    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
6064    writeln!(
6065        code,
6066        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6067    )?;
6068    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
6069    writeln!(code, "    let mut next_token = argmax(&logits);")?;
6070    writeln!(code)?;
6071
6072    writeln!(code, "    if stream {{")?;
6073
6074    // -- SSE streaming response --
6075    writeln!(
6076        code,
6077        "        // SSE streaming: generate tokens and build SSE body"
6078    )?;
6079    writeln!(code, "        let gen_start = Instant::now();")?;
6080    writeln!(code, "        let mut generated_count: usize = 0;")?;
6081    writeln!(code, "        let mut sse_body = String::new();")?;
6082    writeln!(code, "        for _ in 0..max_tokens {{")?;
6083    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6084    writeln!(
6085        code,
6086        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6087    )?;
6088    writeln!(
6089        code,
6090        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
6091    )?;
6092    writeln!(
6093        code,
6094        "                // escaped includes surrounding quotes, strip them"
6095    )?;
6096    writeln!(
6097        code,
6098        "                let inner = &escaped[1..escaped.len()-1];"
6099    )?;
6100    writeln!(code, "                sse_body.push_str(&format!(")?;
6101    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6102    writeln!(code, "                    inner")?;
6103    writeln!(code, "                ));")?;
6104    writeln!(code, "            }}")?;
6105    writeln!(code, "            generated_count += 1;")?;
6106    writeln!(code, "            let logits = model.forward(next_token);")?;
6107    writeln!(code, "            next_token = argmax(&logits);")?;
6108    writeln!(code, "        }}")?;
6109    writeln!(
6110        code,
6111        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6112    )?;
6113    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6114    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
6115    writeln!(code)?;
6116    writeln!(
6117        code,
6118        "        // Final chunk with finish_reason, timing, and DONE sentinel"
6119    )?;
6120    writeln!(code, "        sse_body.push_str(&format!(")?;
6121    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6122    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6123    writeln!(code, "        ));")?;
6124    writeln!(code)?;
6125    writeln!(
6126        code,
6127        "        let resp = tiny_http::Response::from_string(sse_body)"
6128    )?;
6129    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6130    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6131    writeln!(code, "        request.respond(resp).ok();")?;
6132
6133    writeln!(code, "    }} else {{")?;
6134
6135    // -- Non-streaming response --
6136    writeln!(
6137        code,
6138        "        // Non-streaming: generate all tokens, return JSON"
6139    )?;
6140    writeln!(code, "        let gen_start = Instant::now();")?;
6141    writeln!(code, "        let mut generated_count: usize = 0;")?;
6142    writeln!(code, "        let mut generated = String::new();")?;
6143    writeln!(code, "        for _ in 0..max_tokens {{")?;
6144    writeln!(code, "            if next_token == 2 {{ break; }}")?;
6145    writeln!(
6146        code,
6147        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6148    )?;
6149    writeln!(code, "                generated.push_str(&text);")?;
6150    writeln!(code, "            }}")?;
6151    writeln!(code, "            generated_count += 1;")?;
6152    writeln!(code, "            let logits = model.forward(next_token);")?;
6153    writeln!(code, "            next_token = argmax(&logits);")?;
6154    writeln!(code, "        }}")?;
6155    writeln!(
6156        code,
6157        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6158    )?;
6159    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6160    writeln!(code)?;
6161    writeln!(code, "        let resp_json = serde_json::json!({{")?;
6162    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
6163    writeln!(code, "            \"object\": \"chat.completion\",")?;
6164    writeln!(code, "            \"choices\": [{{")?;
6165    writeln!(code, "                \"index\": 0,")?;
6166    writeln!(code, "                \"message\": {{")?;
6167    writeln!(code, "                    \"role\": \"assistant\",")?;
6168    writeln!(code, "                    \"content\": generated")?;
6169    writeln!(code, "                }},")?;
6170    writeln!(code, "                \"finish_reason\": \"stop\"")?;
6171    writeln!(code, "            }}],")?;
6172    writeln!(code, "            \"usage\": {{")?;
6173    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
6174    writeln!(
6175        code,
6176        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6177    )?;
6178    writeln!(
6179        code,
6180        "                \"generation_tokens\": generated_count,"
6181    )?;
6182    writeln!(
6183        code,
6184        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6185    )?;
6186    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
6187    writeln!(code, "            }}")?;
6188    writeln!(code, "        }});")?;
6189    writeln!(
6190        code,
6191        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
6192    )?;
6193    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6194    writeln!(code, "        request.respond(resp).ok();")?;
6195    writeln!(code, "    }}")?;
6196    writeln!(code, "}}")?;
6197
6198    Ok(code)
6199}
6200
6201// ---------------------------------------------------------------------------
6202// Tests
6203// ---------------------------------------------------------------------------
6204
6205#[cfg(test)]
6206mod tests {
6207    use super::*;
6208    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
6209
6210    fn minimal_config() -> ModelConfig {
6211        ModelConfig {
6212            architecture: Architecture::Llama,
6213            hidden_size: 64,
6214            intermediate_size: 128,
6215            num_layers: 2,
6216            num_attention_heads: 4,
6217            num_kv_heads: 4,
6218            head_dim: 16,
6219            vocab_size: 256,
6220            max_seq_len: 512,
6221            rms_norm_eps: 1e-5,
6222            rope_theta: 10000.0,
6223            dtype: DType::F32,
6224            sliding_window_size: None,
6225            qkv_bias: false,
6226        }
6227    }
6228
6229    fn minimal_graph() -> Graph {
6230        Graph::new("test-metal").with_config(minimal_config())
6231    }
6232
6233    #[test]
6234    fn generate_metal_project_creates_files() {
6235        let dir = tempfile::tempdir().unwrap();
6236        let graph = minimal_graph();
6237        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
6238
6239        assert!(
6240            dir.path().join("Cargo.toml").exists(),
6241            "Cargo.toml should be created"
6242        );
6243        assert!(
6244            dir.path().join("src/model.rs").exists(),
6245            "src/model.rs should be created"
6246        );
6247        assert!(
6248            dir.path().join("src/main.rs").exists(),
6249            "src/main.rs should be created"
6250        );
6251        assert!(
6252            dir.path().join("shaders/kernels.metal").exists(),
6253            "shaders/kernels.metal should be created"
6254        );
6255    }
6256
6257    #[test]
6258    fn generated_cargo_toml_has_metal_dep() {
6259        let toml = generate_cargo_toml("my-model");
6260        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
6261        assert!(
6262            toml.contains("tokenizers"),
6263            "Cargo.toml should depend on tokenizers"
6264        );
6265        assert!(
6266            toml.contains("memmap2"),
6267            "Cargo.toml should depend on memmap2"
6268        );
6269        assert!(toml.contains("half"), "Cargo.toml should depend on half");
6270    }
6271
6272    #[test]
6273    fn generated_model_rs_contains_metal_code() {
6274        let config = minimal_config();
6275        let model_rs = generate_model_rs(&config).unwrap();
6276
6277        assert!(
6278            model_rs.contains("pub struct MetalModel"),
6279            "model.rs should define MetalModel struct"
6280        );
6281        assert!(
6282            model_rs.contains("matmul_pipeline: ComputePipelineState"),
6283            "MetalModel should have matmul_pipeline field"
6284        );
6285        assert!(
6286            model_rs.contains("Device::system_default()"),
6287            "model.rs should use Metal device"
6288        );
6289        assert!(
6290            model_rs.contains("new_library_with_source"),
6291            "model.rs should compile Metal shaders"
6292        );
6293        assert!(
6294            model_rs.contains("fn new(weights: &[u8])"),
6295            "MetalModel should implement new()"
6296        );
6297        assert!(
6298            model_rs.contains("fn forward(&mut self, token_id: u32)"),
6299            "MetalModel should implement forward()"
6300        );
6301    }
6302
6303    #[test]
6304    fn generated_shaders_contain_kernel_names() {
6305        let shaders = generate_metal_shaders(&minimal_config());
6306
6307        assert!(
6308            shaders.contains("kernel void matmul_vec"),
6309            "shaders should contain matmul_vec kernel"
6310        );
6311        assert!(
6312            shaders.contains("kernel void rms_norm"),
6313            "shaders should contain rms_norm kernel"
6314        );
6315        assert!(
6316            shaders.contains("kernel void rope"),
6317            "shaders should contain rope kernel"
6318        );
6319        assert!(
6320            shaders.contains("kernel void softmax"),
6321            "shaders should contain softmax kernel"
6322        );
6323        assert!(
6324            shaders.contains("kernel void silu_mul("),
6325            "shaders should contain silu_mul kernel"
6326        );
6327        assert!(
6328            shaders.contains("kernel void silu_mul_fused"),
6329            "shaders should contain silu_mul_fused kernel"
6330        );
6331        assert!(
6332            shaders.contains("kernel void elementwise_add"),
6333            "shaders should contain elementwise_add kernel"
6334        );
6335        assert!(
6336            shaders.contains("kernel void attention"),
6337            "shaders should contain attention kernel"
6338        );
6339        assert!(
6340            shaders.contains("kernel void add_inplace"),
6341            "shaders should contain add_inplace kernel"
6342        );
6343        assert!(
6344            shaders.contains("kernel void copy_buffer"),
6345            "shaders should contain copy_buffer kernel"
6346        );
6347        assert!(
6348            shaders.contains("kernel void copy_offset"),
6349            "shaders should contain copy_offset kernel"
6350        );
6351    }
6352
6353    #[test]
6354    fn generated_shaders_use_simdgroup_features() {
6355        let shaders = generate_metal_shaders(&minimal_config());
6356
6357        assert!(
6358            shaders.contains("threadgroup_barrier"),
6359            "shaders should use threadgroup barriers"
6360        );
6361        assert!(
6362            shaders.contains("threadgroup float"),
6363            "shaders should use threadgroup shared memory"
6364        );
6365        assert!(
6366            shaders.contains("thread_index_in_threadgroup"),
6367            "shaders should use threadgroup indexing"
6368        );
6369        assert!(
6370            shaders.contains("simd_sum"),
6371            "shaders should use simd_sum for warp-level reduction"
6372        );
6373        assert!(
6374            shaders.contains("simd_max"),
6375            "attention kernel should use simd_max for cooperative softmax"
6376        );
6377        assert!(
6378            shaders.contains("thread_index_in_simdgroup"),
6379            "shaders should use simdgroup lane indexing"
6380        );
6381        assert!(
6382            shaders.contains("simdgroup_index_in_threadgroup"),
6383            "shaders should use simdgroup indexing within threadgroup"
6384        );
6385        assert!(
6386            shaders.contains("float4"),
6387            "matmul_vec should use float4 vectorized loads"
6388        );
6389    }
6390
6391    #[test]
6392    fn generated_main_rs_has_tokenizer_usage() {
6393        let config = minimal_config();
6394        let main_rs = generate_main_rs("test-model", &config).unwrap();
6395
6396        assert!(
6397            main_rs.contains("tokenizers::Tokenizer"),
6398            "main.rs should use tokenizers crate"
6399        );
6400        assert!(
6401            main_rs.contains("MetalModel::new"),
6402            "main.rs should call MetalModel::new"
6403        );
6404        assert!(
6405            main_rs.contains("model.forward"),
6406            "main.rs should call model.forward"
6407        );
6408        assert!(
6409            main_rs.contains("memmap2"),
6410            "main.rs should use memmap2 for zero-copy weight loading"
6411        );
6412    }
6413
6414    #[test]
6415    fn missing_config_returns_error() {
6416        let dir = tempfile::tempdir().unwrap();
6417        let graph = Graph::new("no-config");
6418        let result = generate_metal_project(&graph, dir.path(), "fail");
6419        assert!(
6420            matches!(result, Err(MetalCodegenError::MissingConfig)),
6421            "should fail with MissingConfig when graph has no config"
6422        );
6423    }
6424
6425    #[test]
6426    fn sanitize_name_works() {
6427        assert_eq!(sanitize_name("My Model!"), "my-model");
6428        assert_eq!(sanitize_name("test_model"), "test-model");
6429        assert_eq!(sanitize_name("simple"), "simple");
6430    }
6431
6432    #[test]
6433    fn generated_forward_uses_single_command_buffer() {
6434        let config = minimal_config();
6435        let model_rs = generate_model_rs(&config).unwrap();
6436
6437        // The forward function should create exactly one command buffer.
6438        // Use the exact signature to avoid matching forward_prefill/forward_profile.
6439        let forward_start = model_rs
6440            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6441            .unwrap();
6442        let forward_body = &model_rs[forward_start..];
6443        // End at the next pub/private method
6444        let forward_end = forward_body
6445            .find("\n    pub fn forward_profile")
6446            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6447            .or_else(|| forward_body.find("\n    fn dispatch_"))
6448            .unwrap_or(forward_body.len());
6449        let forward_code = &forward_body[..forward_end];
6450
6451        // Should have exactly one new_command_buffer call
6452        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
6453        assert_eq!(
6454            cmd_buf_count, 1,
6455            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
6456        );
6457
6458        // Should have exactly one commit call
6459        let commit_count = forward_code.matches("cmd.commit()").count();
6460        assert_eq!(
6461            commit_count, 1,
6462            "forward() should commit exactly once, found {commit_count}"
6463        );
6464
6465        // Should wait: once for cmd + possibly once for prev_cmd drain
6466        let wait_count = forward_code.matches("wait_until_completed()").count();
6467        assert!(
6468            wait_count >= 1 && wait_count <= 2,
6469            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
6470        );
6471    }
6472
6473    #[test]
6474    fn generated_model_has_preallocated_working_buffers() {
6475        let config = minimal_config();
6476        let model_rs = generate_model_rs(&config).unwrap();
6477
6478        for buf_name in &[
6479            "normed_buf",
6480            "qkv_buf",
6481            "attn_out_buf",
6482            "attn_proj_buf",
6483            "gate_up_buf",
6484            "ffn_hidden_buf",
6485            "ffn_out_buf",
6486            "add_tmp_buf",
6487        ] {
6488            assert!(
6489                model_rs.contains(&format!("{buf_name}: Buffer")),
6490                "MetalModel should have pre-allocated {buf_name} field"
6491            );
6492        }
6493    }
6494
6495    #[test]
6496    fn generated_dispatch_helpers_take_compute_encoder_ref() {
6497        let config = minimal_config();
6498        let model_rs = generate_model_rs(&config).unwrap();
6499
6500        for method in &[
6501            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
6502            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
6503            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
6504            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
6505            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
6506            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
6507            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
6508            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
6509            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
6510            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
6511            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
6512            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
6513            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
6514        ] {
6515            assert!(
6516                model_rs.contains(method),
6517                "model.rs should contain dispatch helper: {method}"
6518            );
6519        }
6520    }
6521
6522    #[test]
6523    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
6524        let config = minimal_config();
6525        let model_rs = generate_model_rs(&config).unwrap();
6526
6527        // Find dispatch helpers section and check none create their own encoders
6528        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
6529        let helpers_code = &model_rs[helpers_start..];
6530
6531        // None of the dispatch_ helpers should call new_command_buffer
6532        assert!(
6533            !helpers_code.contains("self.queue.new_command_buffer()"),
6534            "dispatch helpers should not create their own command buffers"
6535        );
6536
6537        // None should create their own compute encoders
6538        assert!(
6539            !helpers_code.contains("new_compute_command_encoder()"),
6540            "dispatch helpers should not create their own compute encoders"
6541        );
6542
6543        // None should call end_encoding
6544        assert!(
6545            !helpers_code.contains("end_encoding()"),
6546            "dispatch helpers should not call end_encoding"
6547        );
6548
6549        // None should call commit or wait
6550        assert!(
6551            !helpers_code.contains(".commit()"),
6552            "dispatch helpers should not commit command buffers"
6553        );
6554        assert!(
6555            !helpers_code.contains("wait_until_completed"),
6556            "dispatch helpers should not wait on command buffers"
6557        );
6558    }
6559
6560    #[test]
6561    fn generated_forward_batches_compute_encoders() {
6562        let config = minimal_config();
6563        let model_rs = generate_model_rs(&config).unwrap();
6564
6565        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
6566        let forward_start = model_rs
6567            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6568            .unwrap();
6569        let forward_body = &model_rs[forward_start..];
6570        let forward_end = forward_body
6571            .find("\n    pub fn forward_profile")
6572            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6573            .or_else(|| forward_body.find("\n    fn dispatch_"))
6574            .unwrap_or(forward_body.len());
6575        let forward_code = &forward_body[..forward_end];
6576
6577        // Forward should not allocate new buffers
6578        assert!(
6579            !forward_code.contains("device.new_buffer"),
6580            "forward() should not allocate new buffers per call"
6581        );
6582
6583        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
6584        // Copy operations use compute copy kernels instead of blit encoders.
6585        let compute_encoder_count = forward_code
6586            .matches("new_compute_command_encoder()")
6587            .count();
6588        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
6589
6590        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
6591        assert_eq!(
6592            compute_encoder_count, 1,
6593            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
6594        );
6595        assert_eq!(
6596            blit_encoder_count, 0,
6597            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
6598        );
6599    }
6600
6601    #[test]
6602    fn generated_forward_uses_add_inplace() {
6603        let config = minimal_config();
6604        let model_rs = generate_model_rs(&config).unwrap();
6605
6606        // Should use in-place add (no blit copy-back needed)
6607        assert!(
6608            model_rs.contains("dispatch_add_inplace"),
6609            "forward() should use dispatch_add_inplace for residual connections"
6610        );
6611        assert!(
6612            model_rs.contains("add_inplace_pipeline"),
6613            "MetalModel should have add_inplace_pipeline"
6614        );
6615    }
6616
6617    fn minimal_q8_config() -> ModelConfig {
6618        ModelConfig {
6619            architecture: Architecture::Llama,
6620            hidden_size: 64,
6621            intermediate_size: 128,
6622            num_layers: 2,
6623            num_attention_heads: 4,
6624            num_kv_heads: 4,
6625            head_dim: 16,
6626            vocab_size: 256,
6627            max_seq_len: 512,
6628            rms_norm_eps: 1e-5,
6629            rope_theta: 10000.0,
6630            dtype: DType::Q8_0,
6631            sliding_window_size: None,
6632            qkv_bias: false,
6633        }
6634    }
6635
6636    #[test]
6637    fn generated_shaders_contain_q8_kernel() {
6638        let shaders = generate_metal_shaders(&minimal_config());
6639
6640        assert!(
6641            shaders.contains("kernel void matmul_vec_q8"),
6642            "shaders should contain matmul_vec_q8 kernel"
6643        );
6644        assert!(
6645            shaders.contains("device const uchar* matrix"),
6646            "matmul_vec_q8 should accept raw Q8_0 bytes"
6647        );
6648        assert!(
6649            shaders.contains("packed_short4"),
6650            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
6651        );
6652        assert!(
6653            shaders.contains("as_type<char2>"),
6654            "matmul_vec_q8 should bitcast short lanes to char2"
6655        );
6656        assert!(
6657            shaders.contains("device const half*"),
6658            "matmul_vec_q8 should read f16 scale via half pointer"
6659        );
6660    }
6661
6662    #[test]
6663    fn generated_model_uses_fused_qkv_projections() {
6664        let config = minimal_config();
6665        let model_rs = generate_model_rs(&config).unwrap();
6666
6667        // Should have fused QKV weight in layer buffers
6668        assert!(
6669            model_rs.contains("qkv_weight: Buffer"),
6670            "LayerBuffers should have fused qkv_weight field"
6671        );
6672        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
6673        assert!(
6674            !model_rs.contains("    q_weight: Buffer"),
6675            "LayerBuffers should not have separate q_weight field"
6676        );
6677        assert!(
6678            !model_rs.contains("    k_weight: Buffer"),
6679            "LayerBuffers should not have separate k_weight field"
6680        );
6681        assert!(
6682            !model_rs.contains("    v_weight: Buffer"),
6683            "LayerBuffers should not have separate v_weight field"
6684        );
6685
6686        // Should have fused gate_up_weight
6687        assert!(
6688            model_rs.contains("gate_up_weight: Buffer"),
6689            "LayerBuffers should have fused gate_up_weight field"
6690        );
6691        // Should NOT have separate gate/up weight fields
6692        assert!(
6693            !model_rs.contains("    gate_weight: Buffer"),
6694            "LayerBuffers should not have separate gate_weight field"
6695        );
6696        assert!(
6697            !model_rs.contains("    up_weight: Buffer"),
6698            "LayerBuffers should not have separate up_weight field"
6699        );
6700
6701        // Should have fused working buffers
6702        assert!(
6703            model_rs.contains("qkv_buf: Buffer"),
6704            "MetalModel should have fused qkv_buf"
6705        );
6706        assert!(
6707            model_rs.contains("gate_up_buf: Buffer"),
6708            "MetalModel should have fused gate_up_buf"
6709        );
6710
6711        // Forward pass should use fused dispatch
6712        assert!(
6713            model_rs.contains("dispatch_silu_mul_fused"),
6714            "forward pass should use dispatch_silu_mul_fused"
6715        );
6716        assert!(
6717            model_rs.contains("dispatch_rope_offset"),
6718            "forward pass should use dispatch_rope_offset for fused QKV"
6719        );
6720        assert!(
6721            model_rs.contains("dispatch_attention_offset"),
6722            "forward pass should use dispatch_attention_offset for fused QKV"
6723        );
6724    }
6725
6726    #[test]
6727    fn q8_model_has_matmul_q8_pipeline() {
6728        let config = minimal_q8_config();
6729        let model_rs = generate_model_rs(&config).unwrap();
6730
6731        assert!(
6732            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
6733            "MetalModel should have matmul_q8_pipeline field"
6734        );
6735        assert!(
6736            model_rs.contains("matmul_q8_pipeline,"),
6737            "MetalModel Self should include matmul_q8_pipeline"
6738        );
6739    }
6740
6741    #[test]
6742    fn q8_model_uses_dispatch_matmul_q8() {
6743        let config = minimal_q8_config();
6744        let model_rs = generate_model_rs(&config).unwrap();
6745
6746        assert!(
6747            model_rs.contains("dispatch_matmul_q8"),
6748            "Q8_0 model should use dispatch_matmul_q8 for projections"
6749        );
6750        assert!(
6751            model_rs.contains("fn dispatch_matmul_q8"),
6752            "model.rs should define dispatch_matmul_q8 method"
6753        );
6754    }
6755
6756    #[test]
6757    fn q8_model_loads_raw_bytes_not_dequantized() {
6758        let config = minimal_q8_config();
6759        let model_rs = generate_model_rs(&config).unwrap();
6760
6761        // Should NOT contain dequantization code
6762        assert!(
6763            !model_rs.contains("f16_to_f32"),
6764            "Q8_0 model should not dequantize weights to f32"
6765        );
6766        assert!(
6767            !model_rs.contains("f32_data"),
6768            "Q8_0 model should not create f32 weight data"
6769        );
6770
6771        // Should load raw Q8_0 bytes directly
6772        assert!(
6773            model_rs.contains("total_raw as u64"),
6774            "Q8_0 model should load raw bytes into Metal buffer"
6775        );
6776    }
6777
6778    #[test]
6779    fn q8_model_norms_stay_f32() {
6780        let config = minimal_q8_config();
6781        let model_rs = generate_model_rs(&config).unwrap();
6782
6783        // Norm weights should still use f32 buffers
6784        assert!(
6785            model_rs.contains("let attn_norm = next_f32_buffer"),
6786            "attn_norm should use f32 buffer even for Q8_0 models"
6787        );
6788        assert!(
6789            model_rs.contains("let ffn_norm = next_f32_buffer"),
6790            "ffn_norm should use f32 buffer even for Q8_0 models"
6791        );
6792        assert!(
6793            model_rs.contains("let norm_buf = next_f32_buffer"),
6794            "final norm should use f32 buffer even for Q8_0 models"
6795        );
6796    }
6797
6798    #[test]
6799    fn q8_model_uses_fused_weight_loading() {
6800        let config = minimal_q8_config();
6801        let model_rs = generate_model_rs(&config).unwrap();
6802
6803        // Should use fused Q8 buffer loading for QKV
6804        assert!(
6805            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
6806            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
6807        );
6808        // Should use fused Q8 buffer loading for gate+up
6809        assert!(
6810            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
6811            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
6812        );
6813        // Should still use regular q8 buffer for individual weights
6814        assert!(
6815            model_rs.contains("let o_weight = next_q8_buffer"),
6816            "Q8_0 model should use next_q8_buffer for O weight"
6817        );
6818        assert!(
6819            model_rs.contains("let down_weight = next_q8_buffer"),
6820            "Q8_0 model should use next_q8_buffer for down weight"
6821        );
6822    }
6823
6824    #[test]
6825    fn f32_model_does_not_use_q8_dispatch() {
6826        let config = minimal_config();
6827        let model_rs = generate_model_rs(&config).unwrap();
6828
6829        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
6830        let forward_start = model_rs
6831            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6832            .unwrap();
6833        let forward_body = &model_rs[forward_start..];
6834        let forward_end = forward_body
6835            .find("\n    fn dispatch_")
6836            .unwrap_or(forward_body.len());
6837        let forward_code = &forward_body[..forward_end];
6838
6839        assert!(
6840            !forward_code.contains("dispatch_matmul_q8"),
6841            "f32 model forward should not use dispatch_matmul_q8"
6842        );
6843    }
6844
6845    #[test]
6846    fn q8_dispatch_helper_takes_compute_encoder_ref() {
6847        let config = minimal_q8_config();
6848        let model_rs = generate_model_rs(&config).unwrap();
6849
6850        assert!(
6851            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
6852            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
6853        );
6854    }
6855
6856    #[test]
6857    fn generated_model_has_double_buffered_prefill() {
6858        let config = minimal_config();
6859        let model_rs = generate_model_rs(&config).unwrap();
6860
6861        // MetalModel should have prev_cmd field for double-buffered prefill
6862        assert!(
6863            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
6864            "MetalModel should have prev_cmd field for double-buffered prefill"
6865        );
6866
6867        // Should have forward_prefill method
6868        assert!(
6869            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
6870            "MetalModel should have forward_prefill method"
6871        );
6872
6873        // forward() should drain prev_cmd at the start
6874        assert!(
6875            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
6876            "forward() should drain prev_cmd from previous prefill"
6877        );
6878    }
6879
6880    #[test]
6881    fn generated_main_rs_uses_forward_prefill_for_prompt() {
6882        let config = minimal_config();
6883        let main_rs = generate_main_rs("test-model", &config).unwrap();
6884
6885        assert!(
6886            main_rs.contains("forward_prefill"),
6887            "main.rs should use forward_prefill for intermediate prompt tokens"
6888        );
6889        assert!(
6890            main_rs.contains("double-buffered"),
6891            "main.rs should document double-buffered prefill"
6892        );
6893    }
6894
6895    #[test]
6896    fn generated_shaders_q8_uses_wide_vectorized_loads() {
6897        let shaders = generate_metal_shaders(&minimal_config());
6898
6899        assert!(
6900            shaders.contains("packed_short4"),
6901            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
6902        );
6903        assert!(
6904            shaders.contains("d0[0]"),
6905            "matmul_vec_q8 should index the wide pointer for row 0"
6906        );
6907        assert!(
6908            shaders.contains("as_type<char2>"),
6909            "matmul_vec_q8 should bitcast short lanes to char2"
6910        );
6911        assert!(
6912            shaders.contains("dot("),
6913            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
6914        );
6915    }
6916
6917    // ── Q4_0 tests ──────────────────────────────────────────────────────
6918
6919    fn minimal_q4_config() -> ModelConfig {
6920        ModelConfig {
6921            architecture: Architecture::Llama,
6922            hidden_size: 64,
6923            intermediate_size: 128,
6924            num_layers: 2,
6925            num_attention_heads: 4,
6926            num_kv_heads: 4,
6927            head_dim: 16,
6928            vocab_size: 256,
6929            max_seq_len: 512,
6930            rms_norm_eps: 1e-5,
6931            rope_theta: 10000.0,
6932            dtype: DType::Q4_0,
6933            sliding_window_size: None,
6934            qkv_bias: false,
6935        }
6936    }
6937
6938    #[test]
6939    fn generated_shaders_contain_q4_kernel() {
6940        let shaders = generate_metal_shaders(&minimal_config());
6941
6942        assert!(
6943            shaders.contains("kernel void matmul_vec_q4"),
6944            "shaders should contain matmul_vec_q4 kernel"
6945        );
6946        assert!(
6947            shaders.contains("Q4_ROWS_PER_TG"),
6948            "shaders should define Q4_ROWS_PER_TG constant"
6949        );
6950        assert!(
6951            shaders.contains("Q4_ROWS_PER_SG"),
6952            "shaders should define Q4_ROWS_PER_SG constant"
6953        );
6954    }
6955
6956    #[test]
6957    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
6958        let shaders = generate_metal_shaders(&minimal_config());
6959
6960        // Q4_0 kernel should use uchar4 for packed byte loads
6961        assert!(
6962            shaders.contains("uchar4"),
6963            "matmul_vec_q4 should use uchar4 for packed byte loads"
6964        );
6965        // Should unpack nibbles with &0xF and >>4
6966        assert!(
6967            shaders.contains("&0xF"),
6968            "matmul_vec_q4 should extract low nibble with &0xF"
6969        );
6970        assert!(
6971            shaders.contains(">>4"),
6972            "matmul_vec_q4 should extract high nibble with >>4"
6973        );
6974        // Should subtract 8 to convert unsigned to signed
6975        assert!(
6976            shaders.contains("-8)"),
6977            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
6978        );
6979        // Should use 18-byte block size
6980        assert!(
6981            shaders.contains("blk * 18"),
6982            "matmul_vec_q4 should use 18-byte block stride"
6983        );
6984    }
6985
6986    #[test]
6987    fn q4_model_has_matmul_q4_pipeline() {
6988        let config = minimal_q4_config();
6989        let model_rs = generate_model_rs(&config).unwrap();
6990
6991        assert!(
6992            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
6993            "MetalModel should have matmul_q4_pipeline field"
6994        );
6995        assert!(
6996            model_rs.contains("matmul_q4_pipeline,"),
6997            "MetalModel Self should include matmul_q4_pipeline"
6998        );
6999    }
7000
7001    #[test]
7002    fn q4_model_uses_dispatch_matmul_q4() {
7003        let config = minimal_q4_config();
7004        let model_rs = generate_model_rs(&config).unwrap();
7005
7006        assert!(
7007            model_rs.contains("dispatch_matmul_q4"),
7008            "Q4_0 model should use dispatch_matmul_q4 for projections"
7009        );
7010        assert!(
7011            model_rs.contains("fn dispatch_matmul_q4"),
7012            "model.rs should define dispatch_matmul_q4 method"
7013        );
7014    }
7015
7016    #[test]
7017    fn q4_model_loads_raw_bytes_not_dequantized() {
7018        let config = minimal_q4_config();
7019        let model_rs = generate_model_rs(&config).unwrap();
7020
7021        // Should NOT contain dequantization code
7022        assert!(
7023            !model_rs.contains("f16_to_f32"),
7024            "Q4_0 model should not dequantize weights to f32"
7025        );
7026
7027        // Should load raw Q4_0 bytes directly
7028        assert!(
7029            model_rs.contains("total_raw as u64"),
7030            "Q4_0 model should load raw bytes into Metal buffer"
7031        );
7032    }
7033
7034    #[test]
7035    fn q4_model_norms_stay_f32() {
7036        let config = minimal_q4_config();
7037        let model_rs = generate_model_rs(&config).unwrap();
7038
7039        assert!(
7040            model_rs.contains("let attn_norm = next_f32_buffer"),
7041            "attn_norm should use f32 buffer even for Q4_0 models"
7042        );
7043        assert!(
7044            model_rs.contains("let ffn_norm = next_f32_buffer"),
7045            "ffn_norm should use f32 buffer even for Q4_0 models"
7046        );
7047        assert!(
7048            model_rs.contains("let norm_buf = next_f32_buffer"),
7049            "final norm should use f32 buffer even for Q4_0 models"
7050        );
7051    }
7052
7053    #[test]
7054    fn q4_model_uses_fused_weight_loading() {
7055        let config = minimal_q4_config();
7056        let model_rs = generate_model_rs(&config).unwrap();
7057
7058        assert!(
7059            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7060            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7061        );
7062        assert!(
7063            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7064            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7065        );
7066        assert!(
7067            model_rs.contains("let o_weight = next_q4_buffer"),
7068            "Q4_0 model should use next_q4_buffer for O weight"
7069        );
7070        assert!(
7071            model_rs.contains("let down_weight = next_q4_buffer"),
7072            "Q4_0 model should use next_q4_buffer for down weight"
7073        );
7074    }
7075
7076    #[test]
7077    fn q4_dispatch_helper_takes_compute_encoder_ref() {
7078        let config = minimal_q4_config();
7079        let model_rs = generate_model_rs(&config).unwrap();
7080
7081        assert!(
7082            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
7083            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
7084        );
7085    }
7086
7087    #[test]
7088    fn f32_model_does_not_use_q4_dispatch() {
7089        let config = minimal_config();
7090        let model_rs = generate_model_rs(&config).unwrap();
7091
7092        let forward_start = model_rs
7093            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7094            .unwrap();
7095        let forward_body = &model_rs[forward_start..];
7096        let forward_end = forward_body
7097            .find("\n    fn dispatch_")
7098            .unwrap_or(forward_body.len());
7099        let forward_code = &forward_body[..forward_end];
7100
7101        assert!(
7102            !forward_code.contains("dispatch_matmul_q4"),
7103            "f32 model forward should not use dispatch_matmul_q4"
7104        );
7105    }
7106
7107    #[test]
7108    fn q4_model_lm_head_uses_q4_buffer() {
7109        let config = minimal_q4_config();
7110        let model_rs = generate_model_rs(&config).unwrap();
7111
7112        assert!(
7113            model_rs.contains("let lm_head_buf = next_q4_buffer"),
7114            "Q4_0 model should use next_q4_buffer for lm_head"
7115        );
7116    }
7117
7118    #[test]
7119    fn vec_tile_size_matches_model_dimensions() {
7120        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
7121        let small = minimal_config();
7122        let shaders_small = generate_metal_shaders(&small);
7123        assert!(
7124            shaders_small.contains("vec_tile[128]"),
7125            "vec_tile should be sized to max(hidden, intermediate) = 128"
7126        );
7127
7128        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
7129        let mut large = minimal_config();
7130        large.hidden_size = 2048;
7131        large.intermediate_size = 8192;
7132        let shaders_large = generate_metal_shaders(&large);
7133        assert!(
7134            shaders_large.contains("vec_tile[8192]"),
7135            "vec_tile should be 8192 for models with intermediate=8192"
7136        );
7137        assert!(
7138            !shaders_large.contains("vec_tile[4096]"),
7139            "vec_tile should NOT be hardcoded to 4096"
7140        );
7141    }
7142
7143    #[test]
7144    fn generated_cargo_toml_has_server_deps() {
7145        let toml = generate_cargo_toml("my-model");
7146        assert!(
7147            toml.contains("tiny_http"),
7148            "Cargo.toml should depend on tiny_http"
7149        );
7150        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
7151        assert!(
7152            toml.contains("serde_json"),
7153            "Cargo.toml should depend on serde_json"
7154        );
7155    }
7156
7157    #[test]
7158    fn generated_main_rs_has_serve_mode() {
7159        let config = minimal_config();
7160        let main_rs = generate_main_rs("test-model", &config).unwrap();
7161
7162        assert!(
7163            main_rs.contains("--serve"),
7164            "main.rs should parse --serve flag"
7165        );
7166        assert!(
7167            main_rs.contains("--port"),
7168            "main.rs should parse --port flag"
7169        );
7170        assert!(
7171            main_rs.contains("fn serve("),
7172            "main.rs should define serve function"
7173        );
7174        assert!(
7175            main_rs.contains("tiny_http::Server::http"),
7176            "main.rs should create tiny_http server"
7177        );
7178    }
7179
7180    #[test]
7181    fn generated_main_rs_has_chat_completions_endpoint() {
7182        let config = minimal_config();
7183        let main_rs = generate_main_rs("test-model", &config).unwrap();
7184
7185        assert!(
7186            main_rs.contains("/v1/chat/completions"),
7187            "main.rs should handle /v1/chat/completions endpoint"
7188        );
7189        assert!(
7190            main_rs.contains("/v1/models"),
7191            "main.rs should handle /v1/models endpoint"
7192        );
7193        assert!(
7194            main_rs.contains("/health"),
7195            "main.rs should handle /health endpoint"
7196        );
7197    }
7198
7199    #[test]
7200    fn generated_main_rs_has_sse_streaming() {
7201        let config = minimal_config();
7202        let main_rs = generate_main_rs("test-model", &config).unwrap();
7203
7204        assert!(
7205            main_rs.contains("text/event-stream"),
7206            "main.rs should set SSE content type for streaming"
7207        );
7208        assert!(
7209            main_rs.contains("chat.completion.chunk"),
7210            "main.rs should emit SSE chunks"
7211        );
7212        assert!(
7213            main_rs.contains("[DONE]"),
7214            "main.rs should emit [DONE] sentinel"
7215        );
7216    }
7217
7218    #[test]
7219    fn generated_main_rs_has_chat_message_formatting() {
7220        let config = minimal_config();
7221        let main_rs = generate_main_rs("test-model", &config).unwrap();
7222
7223        assert!(
7224            main_rs.contains("fn format_chat_messages"),
7225            "main.rs should define format_chat_messages function"
7226        );
7227        assert!(
7228            main_rs.contains("<|im_start|>"),
7229            "main.rs should use ChatML format"
7230        );
7231        assert!(
7232            main_rs.contains("<|im_end|>"),
7233            "main.rs should use ChatML format"
7234        );
7235    }
7236
7237    #[test]
7238    fn generated_main_rs_has_request_types() {
7239        let config = minimal_config();
7240        let main_rs = generate_main_rs("test-model", &config).unwrap();
7241
7242        assert!(
7243            main_rs.contains("struct ChatRequest"),
7244            "main.rs should define ChatRequest struct"
7245        );
7246        assert!(
7247            main_rs.contains("struct ChatMessage"),
7248            "main.rs should define ChatMessage struct"
7249        );
7250        assert!(
7251            main_rs.contains("Deserialize"),
7252            "main.rs should derive Deserialize for request types"
7253        );
7254    }
7255
7256    #[test]
7257    fn generated_model_has_reset_method() {
7258        let config = minimal_config();
7259        let model_rs = generate_model_rs(&config).unwrap();
7260
7261        assert!(
7262            model_rs.contains("pub fn reset(&mut self)"),
7263            "model.rs should have a reset() method for multi-request serving"
7264        );
7265        assert!(
7266            model_rs.contains("self.pos = 0"),
7267            "reset() should reset position to 0"
7268        );
7269    }
7270
7271    #[test]
7272    fn generated_main_rs_cli_mode_still_works() {
7273        let config = minimal_config();
7274        let main_rs = generate_main_rs("test-model", &config).unwrap();
7275
7276        // CLI mode should still be functional
7277        assert!(
7278            main_rs.contains("fn cli_mode("),
7279            "main.rs should define cli_mode function"
7280        );
7281        assert!(
7282            main_rs.contains("model.forward"),
7283            "main.rs should call model.forward"
7284        );
7285        assert!(
7286            main_rs.contains("model.forward_prefill"),
7287            "main.rs should call model.forward_prefill"
7288        );
7289    }
7290
7291    // ── Batched prefill tests ──────────────────────────────────────────
7292
7293    #[test]
7294    fn generated_shaders_contain_batch_kernels() {
7295        let shaders = generate_metal_shaders(&minimal_config());
7296
7297        assert!(
7298            shaders.contains("kernel void matmul_vec_batch"),
7299            "shaders should contain matmul_vec_batch kernel"
7300        );
7301        assert!(
7302            shaders.contains("kernel void matmul_vec_q8_batch"),
7303            "shaders should contain matmul_vec_q8_batch kernel"
7304        );
7305        assert!(
7306            shaders.contains("kernel void matmul_q8_gemm_batch"),
7307            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
7308        );
7309        assert!(
7310            shaders.contains("kernel void matmul_vec_q4_batch"),
7311            "shaders should contain matmul_vec_q4_batch kernel"
7312        );
7313        assert!(
7314            shaders.contains("kernel void rms_norm_batch"),
7315            "shaders should contain rms_norm_batch kernel"
7316        );
7317        assert!(
7318            shaders.contains("kernel void silu_mul_fused_batch"),
7319            "shaders should contain silu_mul_fused_batch kernel"
7320        );
7321        assert!(
7322            shaders.contains("kernel void add_inplace_batch"),
7323            "shaders should contain add_inplace_batch kernel"
7324        );
7325        assert!(
7326            shaders.contains("kernel void copy_embedding_batch"),
7327            "shaders should contain copy_embedding_batch kernel"
7328        );
7329    }
7330
7331    #[test]
7332    fn generated_model_has_batch_pipelines() {
7333        let config = minimal_config();
7334        let model_rs = generate_model_rs(&config).unwrap();
7335
7336        for pipeline in &[
7337            "matmul_batch_pipeline",
7338            "matmul_q8_batch_pipeline",
7339            "matmul_q8_gemm_batch_pipeline",
7340            "matmul_q4_batch_pipeline",
7341            "rms_norm_batch_pipeline",
7342            "rope_batch_pipeline",
7343            "silu_mul_fused_batch_pipeline",
7344            "add_inplace_batch_pipeline",
7345            "copy_embedding_batch_pipeline",
7346            "attention_batch_pipeline",
7347            "copy_kv_batch_pipeline",
7348        ] {
7349            assert!(
7350                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
7351                "MetalModel should have {pipeline} field"
7352            );
7353        }
7354    }
7355
7356    #[test]
7357    fn generated_model_has_batch_buffers() {
7358        let config = minimal_config();
7359        let model_rs = generate_model_rs(&config).unwrap();
7360
7361        for buf in &[
7362            "batch_hidden_buf",
7363            "batch_residual_buf",
7364            "batch_qkv_buf",
7365            "batch_attn_out_buf",
7366            "batch_attn_proj_buf",
7367            "batch_gate_up_buf",
7368            "batch_ffn_hidden_buf",
7369            "batch_ffn_out_buf",
7370            "batch_tokens_buf",
7371            "batch_positions_buf",
7372        ] {
7373            assert!(
7374                model_rs.contains(&format!("{buf}: Buffer")),
7375                "MetalModel should have {buf} field"
7376            );
7377        }
7378    }
7379
7380    #[test]
7381    fn generated_model_has_forward_prefill_batch() {
7382        let config = minimal_config();
7383        let model_rs = generate_model_rs(&config).unwrap();
7384
7385        assert!(
7386            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
7387            "MetalModel should have forward_prefill_batch method"
7388        );
7389
7390        // forward_prefill should delegate to forward_prefill_batch
7391        assert!(
7392            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
7393            "forward_prefill should delegate to forward_prefill_batch"
7394        );
7395    }
7396
7397    #[test]
7398    fn generated_model_has_max_batch_size_constant() {
7399        let config = minimal_config();
7400        let model_rs = generate_model_rs(&config).unwrap();
7401
7402        assert!(
7403            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
7404            "model.rs should define MAX_BATCH_SIZE constant"
7405        );
7406    }
7407
7408    #[test]
7409    fn forward_prefill_batch_uses_batch_dispatch() {
7410        let config = minimal_config();
7411        let model_rs = generate_model_rs(&config).unwrap();
7412
7413        let batch_start = model_rs
7414            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
7415            .unwrap();
7416        let batch_body = &model_rs[batch_start..];
7417        let batch_end = batch_body
7418            .find("\n    pub fn reset")
7419            .unwrap_or(batch_body.len());
7420        let batch_code = &batch_body[..batch_end];
7421
7422        // Should use batched dispatch methods
7423        assert!(
7424            batch_code.contains("dispatch_rms_norm_batch"),
7425            "forward_prefill_batch should use dispatch_rms_norm_batch"
7426        );
7427        assert!(
7428            batch_code.contains("dispatch_copy_embedding_batch"),
7429            "forward_prefill_batch should use dispatch_copy_embedding_batch"
7430        );
7431        assert!(
7432            batch_code.contains("dispatch_silu_mul_fused_batch"),
7433            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
7434        );
7435        // Should use batched causal attention dispatch
7436        assert!(
7437            batch_code.contains("dispatch_attention_batch"),
7438            "forward_prefill_batch should use dispatch_attention_batch"
7439        );
7440        // Should use fused KV cache copy (both K and V in one dispatch)
7441        assert!(
7442            batch_code.contains("dispatch_copy_kv_both_batch"),
7443            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
7444        );
7445        // Should use fused RoPE Q+K dispatch
7446        assert!(
7447            batch_code.contains("dispatch_rope_qk_batch"),
7448            "forward_prefill_batch should use dispatch_rope_qk_batch"
7449        );
7450    }
7451
7452    #[test]
7453    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
7454        let config = minimal_q8_config();
7455        let model_rs = generate_model_rs(&config).unwrap();
7456
7457        let batch_start = model_rs
7458            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
7459            .unwrap();
7460        let batch_body = &model_rs[batch_start..];
7461        let batch_end = batch_body
7462            .find("\n    pub fn reset")
7463            .unwrap_or(batch_body.len());
7464        let batch_code = &batch_body[..batch_end];
7465
7466        assert!(
7467            batch_code.contains("dispatch_matmul_q8_batch"),
7468            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
7469        );
7470    }
7471
7472    #[test]
7473    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
7474        let config = minimal_q4_config();
7475        let model_rs = generate_model_rs(&config).unwrap();
7476
7477        let batch_start = model_rs
7478            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
7479            .unwrap();
7480        let batch_body = &model_rs[batch_start..];
7481        let batch_end = batch_body
7482            .find("\n    pub fn reset")
7483            .unwrap_or(batch_body.len());
7484        let batch_code = &batch_body[..batch_end];
7485
7486        assert!(
7487            batch_code.contains("dispatch_matmul_q4_batch"),
7488            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
7489        );
7490    }
7491
7492    #[test]
7493    fn generated_main_rs_uses_batched_prefill() {
7494        let config = minimal_config();
7495        let main_rs = generate_main_rs("test-model", &config).unwrap();
7496
7497        assert!(
7498            main_rs.contains("forward_prefill_batch"),
7499            "main.rs should use forward_prefill_batch for prompt tokens"
7500        );
7501    }
7502
7503    #[test]
7504    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
7505        let config = minimal_config();
7506        let model_rs = generate_model_rs(&config).unwrap();
7507
7508        let batch_start = model_rs
7509            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
7510            .unwrap();
7511        let batch_body = &model_rs[batch_start..];
7512        let batch_end = batch_body
7513            .find("\n    pub fn reset")
7514            .unwrap_or(batch_body.len());
7515        let batch_code = &batch_body[..batch_end];
7516
7517        assert!(
7518            batch_code.contains("dispatch_matmul_batch"),
7519            "f32 forward_prefill_batch should use dispatch_matmul_batch"
7520        );
7521        // Should NOT use Q8 or Q4 batch dispatch
7522        assert!(
7523            !batch_code.contains("dispatch_matmul_q8_batch"),
7524            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
7525        );
7526        assert!(
7527            !batch_code.contains("dispatch_matmul_q4_batch"),
7528            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
7529        );
7530    }
7531
7532    #[test]
7533    fn forward_uses_cpu_embedding_lookup() {
7534        let config = minimal_config();
7535        let model_rs = generate_model_rs(&config).unwrap();
7536
7537        // Find just the forward() body (not forward_profile)
7538        let forward_start = model_rs
7539            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7540            .unwrap();
7541        let forward_body = &model_rs[forward_start..];
7542        let forward_end = forward_body
7543            .find("\n    pub fn forward_profile")
7544            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7545            .unwrap_or(forward_body.len());
7546        let forward_code = &forward_body[..forward_end];
7547
7548        // forward() should use CPU memcpy for embedding lookup (unified memory)
7549        assert!(
7550            forward_code.contains("embed_buf.contents()"),
7551            "forward() should access embed_buf via CPU unified memory for embedding lookup"
7552        );
7553        assert!(
7554            forward_code.contains("copy_nonoverlapping"),
7555            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
7556        );
7557        // forward() should NOT use GPU dispatch for embedding
7558        assert!(
7559            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
7560            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
7561        );
7562    }
7563
7564    #[test]
7565    fn forward_profile_method_exists() {
7566        let config = minimal_config();
7567        let model_rs = generate_model_rs(&config).unwrap();
7568
7569        assert!(
7570            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
7571            "MetalModel should have forward_profile() method"
7572        );
7573        // Profile method should print timing information
7574        assert!(
7575            model_rs.contains("[profile]"),
7576            "forward_profile() should print timing with [profile] prefix"
7577        );
7578        assert!(
7579            model_rs.contains("d_embed"),
7580            "forward_profile() should measure embedding time"
7581        );
7582        assert!(
7583            model_rs.contains("d_layers"),
7584            "forward_profile() should measure layer time"
7585        );
7586        assert!(
7587            model_rs.contains("d_logits"),
7588            "forward_profile() should measure logits time"
7589        );
7590    }
7591
7592    #[test]
7593    fn generated_cli_has_profile_flag() {
7594        let config = minimal_config();
7595        let main_rs = generate_main_rs("test-model", &config).unwrap();
7596
7597        assert!(
7598            main_rs.contains("--profile"),
7599            "CLI should support --profile flag"
7600        );
7601        assert!(
7602            main_rs.contains("forward_profile"),
7603            "CLI should call forward_profile when --profile is set"
7604        );
7605    }
7606
7607    #[test]
7608    fn generated_cli_has_thermal_yield() {
7609        let config = minimal_config();
7610        let main_rs = generate_main_rs("test-model", &config).unwrap();
7611
7612        assert!(
7613            main_rs.contains("yield_now()"),
7614            "CLI generation loop should include thread::yield_now() for thermal management"
7615        );
7616    }
7617
7618    // ── Real-world validation tests ──────────────────────────────────────
7619
7620    #[test]
7621    fn generated_forward_handles_single_token_prompt() {
7622        // With a single token (the first prompt token), forward() should work
7623        // at pos=0 where seq_len=1. The attention kernel must handle the case
7624        // where there is only one KV entry (no prefill context).
7625        let config = minimal_config();
7626        let model_rs = generate_model_rs(&config).unwrap();
7627
7628        // The forward function should accept any u32 token_id (no minimum pos guard)
7629        let forward_start = model_rs
7630            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7631            .expect("forward() must exist");
7632        let forward_body = &model_rs[forward_start..forward_start + 400];
7633
7634        // Should NOT require pos > 0 or seq_len > 1
7635        assert!(
7636            !forward_body.contains("assert!(self.pos > 0"),
7637            "forward() must accept pos=0 (first token with no prefill)"
7638        );
7639
7640        // The attention kernel should handle seq_len=1 via the pos field
7641        assert!(
7642            model_rs.contains("self.pos"),
7643            "forward() should use self.pos to track sequence position"
7644        );
7645    }
7646
7647    #[test]
7648    fn generated_reset_clears_kv_cache_position() {
7649        // After reset(), the model should be in a clean state. The pos field
7650        // must be 0 so new generation starts from scratch.
7651        let config = minimal_config();
7652        let model_rs = generate_model_rs(&config).unwrap();
7653
7654        let reset_start = model_rs
7655            .find("pub fn reset(&mut self)")
7656            .expect("reset() must exist");
7657        let reset_body = &model_rs[reset_start..reset_start + 200];
7658
7659        // Reset must zero the position counter
7660        assert!(
7661            reset_body.contains("self.pos = 0"),
7662            "reset() must set self.pos = 0"
7663        );
7664
7665        // Verify reset clears prev_cmd (double-buffering state)
7666        assert!(
7667            reset_body.contains("self.prev_cmd = None"),
7668            "reset() should clear prev_cmd for clean command buffer state"
7669        );
7670    }
7671
7672    #[test]
7673    fn generated_serve_handles_empty_messages_gracefully() {
7674        // The serve endpoint should not crash when receiving an empty messages array.
7675        // The format_chat_messages function should handle this gracefully.
7676        let config = minimal_config();
7677        let main_rs = generate_main_rs("test-model", &config).unwrap();
7678
7679        // The format_chat_messages function should exist and handle empty input
7680        let format_fn_start = main_rs
7681            .find("fn format_chat_messages")
7682            .expect("format_chat_messages must exist");
7683        let format_fn_body =
7684            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
7685
7686        // It should iterate over messages (an empty slice produces an empty loop)
7687        assert!(
7688            format_fn_body.contains("for msg in messages"),
7689            "format_chat_messages should iterate over the messages slice"
7690        );
7691        // It should always append the assistant prompt suffix
7692        assert!(
7693            format_fn_body.contains("<|im_start|>assistant"),
7694            "format_chat_messages should always append assistant prompt header"
7695        );
7696
7697        // The serve function should call model.reset() before each request
7698        let serve_fn_start = main_rs
7699            .find("fn serve(")
7700            .expect("serve function must exist");
7701        let serve_fn_body = &main_rs[serve_fn_start..];
7702        assert!(
7703            serve_fn_body.contains("model.reset()"),
7704            "serve function should reset model between requests"
7705        );
7706    }
7707
7708    #[test]
7709    fn generated_model_forward_increments_pos() {
7710        // Each forward() call must increment self.pos so the next token
7711        // uses the correct RoPE position and KV cache offset.
7712        let config = minimal_config();
7713        let model_rs = generate_model_rs(&config).unwrap();
7714
7715        let forward_start = model_rs
7716            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7717            .unwrap();
7718        let forward_body = &model_rs[forward_start..];
7719        let forward_end = forward_body
7720            .find("\n    pub fn forward_profile")
7721            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7722            .or_else(|| forward_body.find("\n    fn dispatch_"))
7723            .unwrap_or(forward_body.len());
7724        let forward_code = &forward_body[..forward_end];
7725
7726        assert!(
7727            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
7728            "forward() must increment self.pos after processing a token"
7729        );
7730    }
7731}