Skip to main content

forgellm_codegen_metal/
lib.rs

//! Forge Metal Code Generation — native Apple Silicon GPU inference.
//!
//! Generates a complete Cargo project that runs GPU inference via the `metal`
//! crate (metal-rs), using Metal Shading Language compute kernels compiled at
//! runtime. Targets Apple Silicon unified memory for zero-copy weight loading.
6
7use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13/// Errors during Metal code generation.
14#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16    /// The computation graph has no attached [`ModelConfig`].
17    #[error("graph has no model config")]
18    MissingConfig,
19
20    /// An I/O error during file creation.
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// A formatting error while building source strings.
25    #[error("format error: {0}")]
26    Fmt(#[from] std::fmt::Error),
27}
28
29/// Generate a complete Metal Cargo project from a computation graph.
30///
31/// Creates:
32/// - `Cargo.toml` — with metal, objc, tokenizers, memmap2, half dependencies
33/// - `src/main.rs` — CLI that reads weights + tokenizer, runs Metal inference
34/// - `src/model.rs` — MetalModel struct, compute pipelines, forward pass
35/// - `shaders/kernels.metal` — Metal Shading Language compute kernels
36pub fn generate_metal_project(
37    graph: &Graph,
38    output_dir: &Path,
39    model_name: &str,
40) -> Result<(), MetalCodegenError> {
41    let config = graph
42        .config
43        .as_ref()
44        .ok_or(MetalCodegenError::MissingConfig)?;
45
46    let src_dir = output_dir.join("src");
47    let shader_dir = output_dir.join("shaders");
48    fs::create_dir_all(&src_dir)?;
49    fs::create_dir_all(&shader_dir)?;
50
51    fs::write(
52        output_dir.join("Cargo.toml"),
53        generate_cargo_toml(model_name),
54    )?;
55
56    fs::write(
57        shader_dir.join("kernels.metal"),
58        generate_metal_shaders(config),
59    )?;
60
61    let model_rs = generate_model_rs(config)?;
62    fs::write(src_dir.join("model.rs"), model_rs)?;
63
64    let main_rs = generate_main_rs(model_name, config)?;
65    fs::write(src_dir.join("main.rs"), main_rs)?;
66
67    Ok(())
68}
69
70// ---------------------------------------------------------------------------
71// Internal helpers
72// ---------------------------------------------------------------------------
73
/// Turn an arbitrary model name into a valid Cargo package / binary name.
///
/// The steps are:
/// 1. Lowercase the name.
/// 2. Replace every character that is not an ASCII letter, ASCII digit or `-`
///    with `-`.
/// 3. Trim leading and trailing dashes.
///
/// Two results that Cargo would reject are then patched up:
/// - An empty result (e.g. the input was `""` or `"!!!"`) becomes `"model"`.
/// - A result starting with a digit (e.g. `"7B-chat"`) gets a `"model-"`
///   prefix, because Cargo package names may not begin with a digit.
///
/// Only ASCII is kept. `char::is_alphanumeric` also accepts characters such
/// as `²` or `½`, and those fail Cargo's package-name validation.
fn sanitize_name(name: &str) -> String {
    let dashed: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = dashed.trim_matches('-');

    if trimmed.is_empty() {
        // Cargo rejects empty package names.
        "model".to_string()
    } else if trimmed.starts_with(|c: char| c.is_ascii_digit()) {
        // Cargo rejects package names whose first character is a digit.
        format!("model-{trimmed}")
    } else {
        trimmed.to_string()
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82    let sanitized = sanitize_name(model_name);
83    format!(
84        r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108    )
109}
110
111// ---------------------------------------------------------------------------
112// Metal Shading Language kernels
113// ---------------------------------------------------------------------------
114
115fn generate_metal_shaders(config: &ModelConfig) -> String {
116    // The vec_tile shared memory array must fit the largest column dimension
117    // used in any matmul kernel.  For standard LLM architectures this is
118    // max(hidden_size, intermediate_size).  Apple Silicon provides 32 KB of
119    // threadgroup memory; at 4 bytes per float that caps us at 8192 elements.
120    let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121    r#"//
122// Auto-generated by ForgeLLM Metal codegen.
123// Metal Shading Language compute kernels for transformer inference.
124//
125// Optimized with simdgroup cooperative reductions, shared memory vector
126// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
127// lane, and fast:: math intrinsics for Apple Silicon throughput.
128//
129
130#include <metal_stdlib>
131using namespace metal;
132
133// ── Constants ───────────────────────────────────────────────────────────
134// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
135// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
136// per shared memory load vs 4-row, improving ILP and reducing launches.
137constant constexpr uint SIMDGROUPS_PER_TG = 8;
138constant constexpr uint ROWS_PER_SIMDGROUP = 8;
139constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
140
141// ── matmul_vec ──────────────────────────────────────────────────────────
142// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
143// Uses simdgroup cooperative dot product with shared memory vector caching
144// and float4 vectorized loads. Each simdgroup processes 8 rows for better
145// shared memory reuse (8x vector reuse per load) and instruction-level
146// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
147kernel void matmul_vec(
148    device const float* matrix [[buffer(0)]],
149    device const float* vector [[buffer(1)]],
150    device float* output       [[buffer(2)]],
151    constant uint& rows        [[buffer(3)]],
152    constant uint& cols        [[buffer(4)]],
153    uint tgid [[threadgroup_position_in_grid]],
154    uint tid [[thread_index_in_threadgroup]],
155    uint simd_lane [[thread_index_in_simdgroup]],
156    uint simd_id [[simdgroup_index_in_threadgroup]])
157{
158    // Cooperatively load vector into threadgroup shared memory
159    threadgroup float vec_tile[VEC_TILE_SIZE];  // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
160    for (uint i = tid; i < cols; i += 256) {
161        vec_tile[i] = vector[i];
162    }
163    threadgroup_barrier(mem_flags::mem_threadgroup);
164
165    // Each simdgroup handles 8 consecutive rows
166    uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
167    if (row_base >= rows) return;
168
169    uint base0 = row_base * cols;
170    uint base1 = (row_base + 1) * cols;
171    uint base2 = (row_base + 2) * cols;
172    uint base3 = (row_base + 3) * cols;
173    uint base4 = (row_base + 4) * cols;
174    uint base5 = (row_base + 5) * cols;
175    uint base6 = (row_base + 6) * cols;
176    uint base7 = (row_base + 7) * cols;
177
178    // float4 vectorized accumulation across 8 rows
179    uint cols_vec4 = cols & ~127u;  // largest multiple of 128 <= cols
180    float4 sum4_0 = float4(0.0f);
181    float4 sum4_1 = float4(0.0f);
182    float4 sum4_2 = float4(0.0f);
183    float4 sum4_3 = float4(0.0f);
184    float4 sum4_4 = float4(0.0f);
185    float4 sum4_5 = float4(0.0f);
186    float4 sum4_6 = float4(0.0f);
187    float4 sum4_7 = float4(0.0f);
188
189    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
190        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
191        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
192        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
193        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
194        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
195        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
196        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
197        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
198        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
199    }
200
201    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
202    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
203    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
204    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
205    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
206    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
207    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
208    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
209
210    // Handle remaining elements (cols not divisible by 128)
211    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
212        float vv = vec_tile[j];
213        sum0 += matrix[base0 + j] * vv;
214        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
215        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
216        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
217        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
218        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
219        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
220        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
221    }
222
223    // Simdgroup hardware warp-level reduction
224    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
225    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
226    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
227    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
228
229    // Only first lane writes the results
230    if (simd_lane == 0) {
231        if (row_base     < rows) output[row_base]     = sum0;
232        if (row_base + 1 < rows) output[row_base + 1] = sum1;
233        if (row_base + 2 < rows) output[row_base + 2] = sum2;
234        if (row_base + 3 < rows) output[row_base + 3] = sum3;
235        if (row_base + 4 < rows) output[row_base + 4] = sum4;
236        if (row_base + 5 < rows) output[row_base + 5] = sum5;
237        if (row_base + 6 < rows) output[row_base + 6] = sum6;
238        if (row_base + 7 < rows) output[row_base + 7] = sum7;
239    }
240}
241
242// ── rms_norm ────────────────────────────────────────────────────────────
243// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
244// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
245// via shared memory for minimal synchronization overhead.
246kernel void rms_norm(
247    device const float* input   [[buffer(0)]],
248    device const float* weight  [[buffer(1)]],
249    device float* output        [[buffer(2)]],
250    constant uint& n            [[buffer(3)]],
251    constant float& eps         [[buffer(4)]],
252    uint tid [[thread_index_in_threadgroup]])
253{
254    // Each thread accumulates partial sum-of-squares
255    float sum_sq = 0.0f;
256    for (uint i = tid; i < n; i += 256) {
257        float v = input[i];
258        sum_sq += v * v;
259    }
260
261    // Simdgroup-level reduction (hardware warp sum)
262    sum_sq = simd_sum(sum_sq);
263
264    // Cross-simdgroup reduction via shared memory
265    threadgroup float shared[8];
266    uint simd_id = tid / 32;
267    uint simd_lane = tid % 32;
268    if (simd_lane == 0) {
269        shared[simd_id] = sum_sq;
270    }
271    threadgroup_barrier(mem_flags::mem_threadgroup);
272
273    // First thread computes final inverse RMS
274    if (tid == 0) {
275        float total = 0.0f;
276        for (uint i = 0; i < 8; i++) {
277            total += shared[i];
278        }
279        shared[0] = fast::rsqrt(total / float(n) + eps);
280    }
281    threadgroup_barrier(mem_flags::mem_threadgroup);
282
283    float inv_rms = shared[0];
284
285    // Normalize
286    for (uint i = tid; i < n; i += 256) {
287        output[i] = input[i] * inv_rms * weight[i];
288    }
289}
290
291// ── rope ────────────────────────────────────────────────────────────────
292// Rotary Position Embedding applied in-place.
293// Each thread handles one (head, pair) combination.
294kernel void rope(
295    device float* data        [[buffer(0)]],
296    constant uint& num_heads  [[buffer(1)]],
297    constant uint& head_dim   [[buffer(2)]],
298    constant uint& pos        [[buffer(3)]],
299    constant float& theta     [[buffer(4)]],
300    uint id [[thread_position_in_grid]])
301{
302    uint half_dim = head_dim / 2;
303    uint total_pairs = num_heads * half_dim;
304    if (id >= total_pairs) return;
305
306    uint h = id / half_dim;
307    uint i = id % half_dim;
308    uint off = h * head_dim;
309
310    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
311    float angle = float(pos) * freq;
312    float c = cos(angle);
313    float s = sin(angle);
314
315    float x0 = data[off + 2 * i];
316    float x1 = data[off + 2 * i + 1];
317    data[off + 2 * i]     = x0 * c - x1 * s;
318    data[off + 2 * i + 1] = x0 * s + x1 * c;
319}
320
321// ── softmax ─────────────────────────────────────────────────────────────
322// Numerically stable softmax over a 1-D array.
323// Single-threadgroup kernel with cooperative reduction.
324kernel void softmax(
325    device float* data       [[buffer(0)]],
326    constant uint& n         [[buffer(1)]],
327    uint tid [[thread_index_in_threadgroup]],
328    uint tg_size [[threads_per_threadgroup]])
329{
330    threadgroup float shared_val[256];
331
332    // Pass 1: find max
333    float local_max = -INFINITY;
334    for (uint i = tid; i < n; i += tg_size) {
335        local_max = max(local_max, data[i]);
336    }
337    shared_val[tid] = local_max;
338    threadgroup_barrier(mem_flags::mem_threadgroup);
339
340    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
341        if (tid < stride) {
342            shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
343        }
344        threadgroup_barrier(mem_flags::mem_threadgroup);
345    }
346    float max_val = shared_val[0];
347    threadgroup_barrier(mem_flags::mem_threadgroup);
348
349    // Pass 2: exp and sum
350    float local_sum = 0.0f;
351    for (uint i = tid; i < n; i += tg_size) {
352        float e = fast::exp(data[i] - max_val);
353        data[i] = e;
354        local_sum += e;
355    }
356    shared_val[tid] = local_sum;
357    threadgroup_barrier(mem_flags::mem_threadgroup);
358
359    for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
360        if (tid < stride) {
361            shared_val[tid] += shared_val[tid + stride];
362        }
363        threadgroup_barrier(mem_flags::mem_threadgroup);
364    }
365    float inv_sum = 1.0f / shared_val[0];
366    threadgroup_barrier(mem_flags::mem_threadgroup);
367
368    // Pass 3: normalize
369    for (uint i = tid; i < n; i += tg_size) {
370        data[i] *= inv_sum;
371    }
372}
373
374// ── silu_mul ────────────────────────────────────────────────────────────
375// Fused SiLU activation * element-wise multiply:
376//   output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
377kernel void silu_mul(
378    device const float* gate [[buffer(0)]],
379    device const float* up   [[buffer(1)]],
380    device float* output     [[buffer(2)]],
381    constant uint& n         [[buffer(3)]],
382    uint id [[thread_position_in_grid]])
383{
384    if (id >= n) return;
385    float g = gate[id];
386    output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
387}
388
389// ── silu_mul_fused ─────────────────────────────────────────────────────
390// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
391//   gate = gate_up[0..n], up = gate_up[n..2*n]
392//   output[i] = silu(gate_up[i]) * gate_up[n + i]
393kernel void silu_mul_fused(
394    device const float* gate_up [[buffer(0)]],
395    device float* output        [[buffer(1)]],
396    constant uint& n            [[buffer(2)]],
397    uint id [[thread_position_in_grid]])
398{
399    if (id >= n) return;
400    float g = gate_up[id];
401    float u = gate_up[n + id];
402    output[id] = (g / (1.0f + fast::exp(-g))) * u;
403}
404
405// ── elementwise_add ─────────────────────────────────────────────────────
406// Residual connection: output[i] = a[i] + b[i]
407kernel void elementwise_add(
408    device const float* a  [[buffer(0)]],
409    device const float* b  [[buffer(1)]],
410    device float* output   [[buffer(2)]],
411    constant uint& n       [[buffer(3)]],
412    uint id [[thread_position_in_grid]])
413{
414    if (id >= n) return;
415    output[id] = a[id] + b[id];
416}
417
418// ── copy_buffer ─────────────────────────────────────────────────────────
419// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
420// transitions. Used for KV cache updates and embedding lookup.
421kernel void copy_buffer(
422    device const float* src [[buffer(0)]],
423    device float* dst       [[buffer(1)]],
424    constant uint& count    [[buffer(2)]],
425    uint id [[thread_position_in_grid]])
426{
427    if (id < count) dst[id] = src[id];
428}
429
430// ── copy_offset ─────────────────────────────────────────────────────────
431// Copy with source offset (in floats). Used for embedding table lookup
432// where we need to copy a specific row from a large table.
433kernel void copy_offset(
434    device const float* src     [[buffer(0)]],
435    device float* dst           [[buffer(1)]],
436    constant uint& src_offset   [[buffer(2)]],  // in floats
437    constant uint& count        [[buffer(3)]],
438    uint id [[thread_position_in_grid]])
439{
440    if (id < count) dst[id] = src[src_offset + id];
441}
442
443// ── add_inplace ─────────────────────────────────────────────────────────
444// In-place residual connection: a[i] += b[i]
445// Avoids a separate blit copy for residual add, reducing encoder overhead.
446kernel void add_inplace(
447    device float* a        [[buffer(0)]],
448    device const float* b  [[buffer(1)]],
449    constant uint& n       [[buffer(2)]],
450    uint id [[thread_position_in_grid]])
451{
452    if (id >= n) return;
453    a[id] += b[id];
454}
455
456// ── matmul_vec_q8 ─────────────────────────────────────────────────────
457// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
458// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
459// Operates directly on quantized weights to halve memory bandwidth vs f32,
460// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
461//
462// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
463// because int8->float conversion doubles register demand.  Fully unrolled
464// inner loop with float4 vector loads from shared memory eliminates loop
465// overhead and enables better instruction scheduling.
466// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
467constant constexpr uint Q8_ROWS_PER_SG = 4;
468constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
469
470// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
471// doubles ALU work, so the same register budget applies).
472constant constexpr uint Q4_ROWS_PER_SG = 4;
473constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
474
475kernel void matmul_vec_q8(
476    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes
477    device const float* vector   [[buffer(1)]],  // f32 input
478    device float* output         [[buffer(2)]],
479    constant uint& rows          [[buffer(3)]],
480    constant uint& cols          [[buffer(4)]],  // number of elements per row
481    uint tgid [[threadgroup_position_in_grid]],
482    uint tid [[thread_index_in_threadgroup]],
483    uint simd_lane [[thread_index_in_simdgroup]],
484    uint simd_id [[simdgroup_index_in_threadgroup]])
485{
486    // Load vector into shared memory
487    threadgroup float vec_tile[VEC_TILE_SIZE];
488    for (uint i = tid; i < cols; i += 256) {
489        vec_tile[i] = vector[i];
490    }
491    threadgroup_barrier(mem_flags::mem_threadgroup);
492
493    // Each simdgroup handles 4 consecutive rows (lower register pressure)
494    uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
495    if (row_base >= rows) return;
496
497    // Q8_0: each block is 34 bytes for 32 elements
498    uint blocks_per_row = cols / 32;
499    uint row_bytes = blocks_per_row * 34;
500
501    // Pointers to each row's Q8_0 data
502    device const uchar* r0 = matrix + row_base * row_bytes;
503    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
504    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
505    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
506
507    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
508
509    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
510        uint bb = blk * 34;
511        uint vb = blk * 32;
512
513        // Prefetch all 4 scales
514        float sc0 = float(*(device const half*)(r0 + bb));
515        float sc1 = float(*(device const half*)(r1 + bb));
516        float sc2 = float(*(device const half*)(r2 + bb));
517        float sc3 = float(*(device const half*)(r3 + bb));
518
519        // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
520        // Q8_0 block layout where the int8 data starts at offset +2 from a
521        // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
522        // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
523        // reduction in memory transactions. Metal's char16/packed_char16 are
524        // reserved types and packed_*int4 require >=4-byte alignment which
525        // this layout does not provide, so packed_short4 is the widest valid
526        // vectorized load.
527        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
528        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
529        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
530        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
531
532        // Load all 8 float4 vector values for this 32-element block from shared memory
533        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
534        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
535        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
536        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
537        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
538        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
539        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
540        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
541
542        // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
543        // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
544        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
545            short4 _s = short4(SHORT4); \
546            char2 _a = as_type<char2>(_s.x); \
547            char2 _b = as_type<char2>(_s.y); \
548            char2 _c = as_type<char2>(_s.z); \
549            char2 _d = as_type<char2>(_s.w); \
550            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
551            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
552        }
553
554        float4 f0, f1;
555        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
556
557        // Row 0: 4 short4 loads cover 32 int8 weights
558        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
559        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
560        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
561        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
562
563        // Row 1
564        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
565        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
566        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
567        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
568
569        // Row 2
570        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
571        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
572        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
573        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
574
575        // Row 3
576        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
577        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
578        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
579        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
580
581        #undef Q8_UNPACK8
582
583        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
584    }
585
586    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
587    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
588
589    if (simd_lane == 0) {
590        if (row_base     < rows) output[row_base]     = sum0;
591        if (row_base + 1 < rows) output[row_base + 1] = sum1;
592        if (row_base + 2 < rows) output[row_base + 2] = sum2;
593        if (row_base + 3 < rows) output[row_base + 3] = sum3;
594    }
595}
596
597// ── matmul_vec_q4 ─────────────────────────────────────────────────────
598// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
599// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
600// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
601// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
602//
603// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
604// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
605kernel void matmul_vec_q4(
606    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes
607    device const float* vector   [[buffer(1)]],  // f32 input
608    device float* output         [[buffer(2)]],
609    constant uint& rows          [[buffer(3)]],
610    constant uint& cols          [[buffer(4)]],  // number of elements per row
611    uint tgid [[threadgroup_position_in_grid]],
612    uint tid [[thread_index_in_threadgroup]],
613    uint simd_lane [[thread_index_in_simdgroup]],
614    uint simd_id [[simdgroup_index_in_threadgroup]])
615{
616    // Load vector into shared memory
617    threadgroup float vec_tile[VEC_TILE_SIZE];
618    for (uint i = tid; i < cols; i += 256) {
619        vec_tile[i] = vector[i];
620    }
621    threadgroup_barrier(mem_flags::mem_threadgroup);
622
623    // Each simdgroup handles 4 consecutive rows
624    uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
625    if (row_base >= rows) return;
626
627    // Q4_0: each block is 18 bytes for 32 elements
628    uint blocks_per_row = cols / 32;
629    uint row_bytes = blocks_per_row * 18;
630
631    // Pointers to each row's Q4_0 data
632    device const uchar* r0 = matrix + row_base * row_bytes;
633    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
634    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
635    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
636
637    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
638
639    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
640        uint bb = blk * 18;
641        uint vb = blk * 32;
642
643        // Prefetch all 4 scales
644        float sc0 = float(*(device const half*)(r0 + bb));
645        float sc1 = float(*(device const half*)(r1 + bb));
646        float sc2 = float(*(device const half*)(r2 + bb));
647        float sc3 = float(*(device const half*)(r3 + bb));
648
649        // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
650        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
651        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
652        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
653        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
654
655        // Load 8 float4 vector values for 32 elements from shared memory
656        // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
657        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);       // [0..3]
658        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);   // [4..7]
659        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);   // [8..11]
660        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);  // [12..15]
661        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);  // [16..19]
662        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);  // [20..23]
663        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);  // [24..27]
664        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);  // [28..31]
665
666        // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
667        // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
668        float bd0=0, bd1=0, bd2=0, bd3=0;
669        uchar4 b;
670
671        // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
672        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
673                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
674                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
675                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
676        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
677                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
678                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
679                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
680        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
681                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
682                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
683                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
684        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
685                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
686                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
687                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
688
689        // Row 1
690        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
691                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
692                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
693                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
694        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
695                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
696                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
697                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
698        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
699                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
700                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
701                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
702        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
703                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
704                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
705                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
706
707        // Row 2
708        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
709                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
710                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
711                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
712        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
713                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
714                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
715                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
716        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
717                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
718                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
719                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
720        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
721                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
722                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
723                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
724
725        // Row 3
726        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
727                    +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
728                    +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
729                    +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
730        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
731                    +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
732                    +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
733                    +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
734        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
735                    +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
736                    +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
737                    +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
738        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
739                    +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
740                    +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
741                    +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
742
743        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
744    }
745
746    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
747    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
748
749    if (simd_lane == 0) {
750        if (row_base     < rows) output[row_base]     = sum0;
751        if (row_base + 1 < rows) output[row_base + 1] = sum1;
752        if (row_base + 2 < rows) output[row_base + 2] = sum2;
753        if (row_base + 3 < rows) output[row_base + 3] = sum3;
754    }
755}
756
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then weighted sum of V.
// Each threadgroup handles one head with 256 threads (8 simdgroups).
//
// The context is walked in tiles of SCORE_TILE positions so the threadgroup
// score scratch stays at 8 KB for any seq_len. A single fixed-size scratch
// indexed by absolute position would overflow threadgroup memory once
// seq_len exceeds its length. Per tile, scores are computed, exponentiated
// against a running max (online softmax) and folded into output, with the
// partial result from earlier tiles rescaled onto the new max. A final pass
// divides by the accumulated softmax denominator. When seq_len <= SCORE_TILE
// there is exactly one tile.
//
// Buffers:
//   q:       [num_heads * head_dim]       current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]       also used as the running accumulator
kernel void attention(
    device const float* q        [[buffer(0)]],
    device const float* k_cache  [[buffer(1)]],
    device const float* v_cache  [[buffer(2)]],
    device float* output         [[buffer(3)]],
    constant uint& seq_len       [[buffer(4)]],
    constant uint& num_heads     [[buffer(5)]],
    constant uint& num_kv_heads  [[buffer(6)]],
    constant uint& head_dim      [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Positions scored per tile: 2048 floats = 8 KB of threadgroup memory.
    constexpr uint SCORE_TILE = 2048;

    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;
    uint kv_stride = num_kv_heads * head_dim;  // floats per cached position
    uint kv_off = kv_head * head_dim;
    float scale = fast::rsqrt(float(head_dim));

    // Empty context: nothing to attend to, emit zeros.
    if (seq_len == 0) {
        for (uint d = tid; d < head_dim; d += 256) {
            output[q_off + d] = 0.0f;
        }
        return;
    }

    threadgroup float scores[SCORE_TILE];
    threadgroup float shared_max[8];
    threadgroup float shared_sum[8];

    // Online-softmax state. Every thread derives these from the same
    // threadgroup-wide reductions, so they are uniform across the group.
    float running_max = -INFINITY;
    float running_sum = 0.0f;

    for (uint tile = 0; tile < seq_len; tile += SCORE_TILE) {
        uint tile_len = min(uint(SCORE_TILE), seq_len - tile);

        // Step 1: scaled Q·K^T for this tile, one position per simdgroup.
        for (uint j = simd_id; j < tile_len; j += 8) {
            uint k_off = (tile + j) * kv_stride + kv_off;
            float dot = 0.0;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q[q_off + d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores[j] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Step 2: tile max, merged into the running max.
        float local_max = -INFINITY;
        for (uint j = tid; j < tile_len; j += 256) {
            local_max = max(local_max, scores[j]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) shared_max[simd_id] = local_max;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        float tile_max = shared_max[0];
        for (uint i = 1; i < 8; i++) tile_max = max(tile_max, shared_max[i]);
        float new_max = max(running_max, tile_max);
        // Factor that rebases state from earlier tiles onto new_max. The
        // first tile has no earlier state; skipping exp there also keeps
        // -INFINITY away from fast::exp.
        float correction = (tile == 0) ? 0.0f : fast::exp(running_max - new_max);

        // Step 3: exponentiate in place and sum. Each thread touches only the
        // scores[j] it read in step 2, so no barrier is needed in between.
        float local_sum = 0.0;
        for (uint j = tid; j < tile_len; j += 256) {
            float p = fast::exp(scores[j] - new_max);
            scores[j] = p;
            local_sum += p;
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) shared_sum[simd_id] = local_sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        float tile_sum = 0.0;
        for (uint i = 0; i < 8; i++) tile_sum += shared_sum[i];
        running_sum = running_sum * correction + tile_sum;
        running_max = new_max;

        // Step 4: accumulate the unnormalized weighted sum of V into output.
        // Each thread owns output[q_off + d] for d = tid, tid + 256, ..., so
        // reading back its own partial result needs no synchronization.
        // Process 4 sequence positions at a time for better ILP.
        uint tile_len4 = tile_len & ~3u;  // largest multiple of 4 <= tile_len
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = 0.0;
            uint v_base = tile * kv_stride + kv_off + d;
            for (uint j = 0; j < tile_len4; j += 4) {
                acc += scores[j]     * v_cache[j * kv_stride + v_base]
                     + scores[j + 1] * v_cache[(j + 1) * kv_stride + v_base]
                     + scores[j + 2] * v_cache[(j + 2) * kv_stride + v_base]
                     + scores[j + 3] * v_cache[(j + 3) * kv_stride + v_base];
            }
            for (uint j = tile_len4; j < tile_len; j++) {
                acc += scores[j] * v_cache[j * kv_stride + v_base];
            }
            float prev = (tile == 0) ? 0.0f : output[q_off + d] * correction;
            output[q_off + d] = prev + acc;
        }
        // The next tile overwrites scores[] and the reduction slots.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Normalize by the softmax denominator (>= 1: the max term is exp(0)).
    float inv_sum = 1.0f / running_sum;
    for (uint d = tid; d < head_dim; d += 256) {
        output[q_off + d] *= inv_sum;
    }
}
873
874// ── Batched prefill kernels ────────────────────────────────────────────
875// These kernels process M input vectors against the same weight matrix
876// in a single dispatch, converting mat-vec into mat-mat for better GPU
877// utilization during prompt prefill.
878
879// ── rms_norm_batch ─────────────────────────────────────────────────────
880// RMS normalization for a batch of vectors.
881// Each threadgroup handles one vector: input[token * n .. (token+1) * n].
882// Grid: M threadgroups (one per token).
883kernel void rms_norm_batch(
884    device const float* input   [[buffer(0)]],  // [M, n]
885    device const float* weight  [[buffer(1)]],  // [n]
886    device float* output        [[buffer(2)]],  // [M, n]
887    constant uint& n            [[buffer(3)]],
888    constant float& eps         [[buffer(4)]],
889    constant uint& num_tokens   [[buffer(5)]],
890    uint tgid [[threadgroup_position_in_grid]],
891    uint tid [[thread_index_in_threadgroup]])
892{
893    if (tgid >= num_tokens) return;
894
895    uint base = tgid * n;
896
897    float sum_sq = 0.0f;
898    for (uint i = tid; i < n; i += 256) {
899        float v = input[base + i];
900        sum_sq += v * v;
901    }
902
903    sum_sq = simd_sum(sum_sq);
904
905    threadgroup float shared[8];
906    uint simd_id = tid / 32;
907    uint simd_lane = tid % 32;
908    if (simd_lane == 0) {
909        shared[simd_id] = sum_sq;
910    }
911    threadgroup_barrier(mem_flags::mem_threadgroup);
912
913    if (tid == 0) {
914        float total = 0.0f;
915        for (uint i = 0; i < 8; i++) {
916            total += shared[i];
917        }
918        shared[0] = fast::rsqrt(total / float(n) + eps);
919    }
920    threadgroup_barrier(mem_flags::mem_threadgroup);
921
922    float inv_rms = shared[0];
923
924    for (uint i = tid; i < n; i += 256) {
925        output[base + i] = input[base + i] * inv_rms * weight[i];
926    }
927}
928
// ── rope_batch ─────────────────────────────────────────────────────────
// Batched rotary position embedding, applied in place.
// data is laid out as [M, num_heads * head_dim] and positions[t] is the
// position of token t. Each thread rotates one adjacent element pair
// (2p, 2p+1) of one head of one token by angle position * theta^(-2p/head_dim).
// Threads beyond M * num_heads * (head_dim / 2) exit immediately.
kernel void rope_batch(
    device float* data           [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads     [[buffer(1)]],
    constant uint& head_dim      [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float& theta        [[buffer(4)]],
    constant uint& num_tokens    [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;

    // Checked before any division so an empty workload never divides by 0.
    if (id >= num_tokens * pairs_per_token) return;

    const uint token = id / pairs_per_token;
    const uint pair_in_token = id % pairs_per_token;
    const uint head_idx = pair_in_token / pairs_per_head;
    const uint pair_idx = pair_in_token % pairs_per_head;

    device float* elem = data + token * (num_heads * head_dim) + head_idx * head_dim + 2 * pair_idx;

    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair_idx) / float(head_dim));
    const float angle = float(positions[token]) * inv_freq;
    const float cos_a = cos(angle);
    const float sin_a = sin(angle);

    const float even = elem[0];
    const float odd = elem[1];
    elem[0] = even * cos_a - odd * sin_a;
    elem[1] = even * sin_a + odd * cos_a;
}
963
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Batched fused SiLU-multiply. Row t of gate_up holds n gate values followed
// by n up values; row t of output is silu(gate) * up elementwise, with
// silu(g) = g / (1 + exp(-g)):
//   output[t*n + i] = silu(gate_up[t*2n + i]) * gate_up[t*2n + n + i]
// One thread per output element; surplus threads are masked off.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]],  // [M, 2*n]
    device float* output        [[buffer(1)]],  // [M, n]
    constant uint& n            [[buffer(2)]],
    constant uint& num_tokens   [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * n) {
        uint row = id / n;
        uint col = id - row * n;
        device const float* gate = gate_up + row * 2 * n;
        device const float* up = gate + n;
        float g = gate[col];
        // id == row * n + col, i.e. the same flat index into output.
        output[id] = (g / (1.0f + fast::exp(-g))) * up[col];
    }
}
983
// ── add_inplace_batch ──────────────────────────────────────────────────
// Batched in-place residual add over a flat [M * n] buffer:
//   a[i] = a[i] + b[i] for every i < total.
// One thread per element; threads at or beyond total do nothing.
kernel void add_inplace_batch(
    device float* a        [[buffer(0)]],  // [M * n]
    device const float* b  [[buffer(1)]],  // [M * n]
    constant uint& total   [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
995
// ── copy_embedding_batch ───────────────────────────────────────────────
// Embedding gather for a batch: output row t (dim floats) is a copy of row
// tokens[t] of the embedding table. One thread per output float.
// Token IDs are not range-checked here; the caller presumably guarantees
// tokens[t] < vocab_size (NOTE(review): confirm on the host side).
kernel void copy_embedding_batch(
    device const float* embed   [[buffer(0)]],  // [vocab_size, dim]
    device float* output        [[buffer(1)]],  // [M, dim]
    device const uint* tokens   [[buffer(2)]],  // [M]
    constant uint& dim          [[buffer(3)]],
    constant uint& num_tokens   [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * dim) {
        uint row = id / dim;
        uint col = id - row * dim;
        uint src_row = tokens[row];
        output[id] = embed[src_row * dim + col];
    }
}
1013
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair.
//
// Computes outputs[t * rows + r] = sum over c of matrix[r * cols + c] *
// inputs[t * cols + c] for f32 weights. The token's input row is staged in
// threadgroup memory, then each simdgroup produces 8 consecutive output rows
// starting at row_base: lanes stride across columns (float4-wide in the main
// loop, scalar in the tail) and simd_sum folds the lane partials.
// NOTE(review): assumes 256 threads per threadgroup (staging stride),
// cols <= VEC_TILE_SIZE, ROWS_PER_SIMDGROUP == 8 (eight accumulators), and
// cols a multiple of 4 so the float4 matrix loads are 16-byte aligned —
// confirm against the constants and the host-side dispatch.
kernel void matmul_vec_batch(
    device const float* matrix  [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs  [[buffer(1)]],  // [M, cols] input batch
    device float* outputs       [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens   [[buffer(3)]],  // M
    constant uint& rows         [[buffer(4)]],
    constant uint& cols         [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Decode the flat threadgroup index token-major: all row groups of
    // token 0 first, then token 1, and so on.
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // First of this simdgroup's rows. Returning here is safe because no
    // threadgroup barrier follows.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Columns covered by the float4 loop: the largest multiple of 128
    // (32 lanes x 4 floats) not exceeding cols.
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    // Row 0 needs no bounds check (row_base < rows was verified above);
    // rows 1..7 may fall past the end of the matrix and are skipped.
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Horizontal add of each float4 accumulator into a scalar per row.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail: columns [cols_vec4, cols), one column per lane per step.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Combine the 32 lane partials of each row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    // Lane 0 stores the in-range rows into this token's output row.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1115
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups.
//
// Computes outputs[t * rows + r] = sum over c of W[r, c] * inputs[t * cols + c].
// Each Q8_0 block of 32 weights is stored as 34 bytes: a half scale followed
// by 32 signed int8 values (W = scale * q). The flat threadgroup index is
// decoded token-major, the token's input row is staged in threadgroup
// memory, and each simdgroup accumulates four consecutive rows (r0..r3)
// with lanes striding over blocks, finishing with simd_sum.
//
// When rows is not a multiple of four, some of the last simdgroup's rows do
// not exist. The block loop reads all four rows unconditionally (it has no
// per-row branches), so those row pointers are clamped to the last valid
// row. This keeps every load inside the weight buffer; the guarded stores
// at the end discard the duplicate sums.
// NOTE(review): assumes 256 threads per threadgroup, cols a multiple of 32,
// cols <= VEC_TILE_SIZE, and Q8_ROWS_PER_SG == 4 — confirm against the
// constants and host dispatch.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],  // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No threadgroup barrier follows, so per-simdgroup exit is safe.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Clamp nonexistent trailing rows onto the last real row so the
    // unconditional loads below never run past the end of the buffer.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;
        uint vb = blk * 32;

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 int8 weights) into two float4s:
        // bytes 0..3 into OUT_LO and bytes 4..7 into OUT_HI.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply each block's scale once to its integer-weight dot product.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Only real rows are stored; clamped duplicates are dropped here.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1231
1232// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
1233// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
1234// Each threadgroup covers 32 rows and TOKENS_PER_TG consecutive tokens, so
1235// the Q8_0 weight blocks are fetched once from device memory and reused for
1236// every token in the tile (1/TOKENS_PER_TG the weight bandwidth of the
1237// per-token dispatch).
1238//
1239// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG)) threadgroups.
1240// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
1241// with simd_sum.  Token vectors are read directly from device memory inside
1242// the block loop (not cached in shared memory) so intermediate_size up to
1243// 8192 fits without spilling threadgroup memory.
// Tokens per threadgroup tile in matmul_q8_gemm_batch; the grid's y dimension
// is ceil(M / TOKENS_PER_TG_Q8). The kernel body hand-unrolls exactly four
// per-token accumulator sets (s0x..s3x), so changing this value requires a
// matching change to that kernel.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;
1245
1246kernel void matmul_q8_gemm_batch(
1247    device const uchar* matrix   [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
1248    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1249    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1250    constant uint& num_tokens    [[buffer(3)]],  // M
1251    constant uint& rows          [[buffer(4)]],
1252    constant uint& cols          [[buffer(5)]],
1253    uint2 tgid [[threadgroup_position_in_grid]],
1254    uint tid [[thread_index_in_threadgroup]],
1255    uint simd_lane [[thread_index_in_simdgroup]],
1256    uint simd_id [[simdgroup_index_in_threadgroup]])
1257{
1258    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
1259    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
1260    if (row_base >= rows || tok_base >= num_tokens) return;
1261
1262    // How many tokens in this tile are valid?
1263    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);
1264
1265    uint blocks_per_row = cols / 32;
1266    uint row_bytes = blocks_per_row * 34;
1267
1268    device const uchar* r0 = matrix + row_base * row_bytes;
1269    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1270    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1271    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1272
1273    // Accumulators: 4 tokens × 4 rows per simdgroup.
1274    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
1275    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
1276    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
1277    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;
1278
1279    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1280        uint bb = blk * 34;
1281        uint vb = blk * 32;
1282
1283        // ── Load weight data ONCE per block (reused across all tokens) ──
1284        float sc0 = float(*(device const half*)(r0 + bb));
1285        float sc1 = float(*(device const half*)(r1 + bb));
1286        float sc2 = float(*(device const half*)(r2 + bb));
1287        float sc3 = float(*(device const half*)(r3 + bb));
1288
1289        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
1290        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
1291        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
1292        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
1293
1294        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
1295            short4 _s = short4(SHORT4); \
1296            char2 _a = as_type<char2>(_s.x); \
1297            char2 _b = as_type<char2>(_s.y); \
1298            char2 _c = as_type<char2>(_s.z); \
1299            char2 _d = as_type<char2>(_s.w); \
1300            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
1301            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
1302        }
1303
1304        // Unpack all 4 rows × 8 float4 weights (scaled).  These live in
1305        // registers for the duration of the block and are dotted against
1306        // every token's vector tile.
1307        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
1308        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
1309        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
1310        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;
1311
1312        Q8_UNPACK8(d0[0], w0_0, w0_1);
1313        Q8_UNPACK8(d0[1], w0_2, w0_3);
1314        Q8_UNPACK8(d0[2], w0_4, w0_5);
1315        Q8_UNPACK8(d0[3], w0_6, w0_7);
1316
1317        Q8_UNPACK8(d1[0], w1_0, w1_1);
1318        Q8_UNPACK8(d1[1], w1_2, w1_3);
1319        Q8_UNPACK8(d1[2], w1_4, w1_5);
1320        Q8_UNPACK8(d1[3], w1_6, w1_7);
1321
1322        Q8_UNPACK8(d2[0], w2_0, w2_1);
1323        Q8_UNPACK8(d2[1], w2_2, w2_3);
1324        Q8_UNPACK8(d2[2], w2_4, w2_5);
1325        Q8_UNPACK8(d2[3], w2_6, w2_7);
1326
1327        Q8_UNPACK8(d3[0], w3_0, w3_1);
1328        Q8_UNPACK8(d3[1], w3_2, w3_3);
1329        Q8_UNPACK8(d3[2], w3_4, w3_5);
1330        Q8_UNPACK8(d3[3], w3_6, w3_7);
1331
1332        #undef Q8_UNPACK8
1333
1334        // ── For each token, read vector and accumulate against shared weights ──
1335        // Token 0 (always valid: tok_count >= 1).
1336        {
1337            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
1338            float4 v0 = *(device const float4*)(a0);
1339            float4 v1 = *(device const float4*)(a0 + 4);
1340            float4 v2 = *(device const float4*)(a0 + 8);
1341            float4 v3 = *(device const float4*)(a0 + 12);
1342            float4 v4 = *(device const float4*)(a0 + 16);
1343            float4 v5 = *(device const float4*)(a0 + 20);
1344            float4 v6 = *(device const float4*)(a0 + 24);
1345            float4 v7 = *(device const float4*)(a0 + 28);
1346            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1347                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1348            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1349                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1350            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1351                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1352            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1353                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1354            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
1355        }
1356        // Token 1
1357        if (tok_count > 1) {
1358            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
1359            float4 v0 = *(device const float4*)(a1);
1360            float4 v1 = *(device const float4*)(a1 + 4);
1361            float4 v2 = *(device const float4*)(a1 + 8);
1362            float4 v3 = *(device const float4*)(a1 + 12);
1363            float4 v4 = *(device const float4*)(a1 + 16);
1364            float4 v5 = *(device const float4*)(a1 + 20);
1365            float4 v6 = *(device const float4*)(a1 + 24);
1366            float4 v7 = *(device const float4*)(a1 + 28);
1367            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1368                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1369            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1370                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1371            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1372                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1373            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1374                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1375            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
1376        }
1377        // Token 2
1378        if (tok_count > 2) {
1379            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
1380            float4 v0 = *(device const float4*)(a2);
1381            float4 v1 = *(device const float4*)(a2 + 4);
1382            float4 v2 = *(device const float4*)(a2 + 8);
1383            float4 v3 = *(device const float4*)(a2 + 12);
1384            float4 v4 = *(device const float4*)(a2 + 16);
1385            float4 v5 = *(device const float4*)(a2 + 20);
1386            float4 v6 = *(device const float4*)(a2 + 24);
1387            float4 v7 = *(device const float4*)(a2 + 28);
1388            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1389                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1390            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1391                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1392            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1393                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1394            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1395                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1396            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
1397        }
1398        // Token 3
1399        if (tok_count > 3) {
1400            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
1401            float4 v0 = *(device const float4*)(a3);
1402            float4 v1 = *(device const float4*)(a3 + 4);
1403            float4 v2 = *(device const float4*)(a3 + 8);
1404            float4 v3 = *(device const float4*)(a3 + 12);
1405            float4 v4 = *(device const float4*)(a3 + 16);
1406            float4 v5 = *(device const float4*)(a3 + 20);
1407            float4 v6 = *(device const float4*)(a3 + 24);
1408            float4 v7 = *(device const float4*)(a3 + 28);
1409            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
1410                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
1411            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
1412                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
1413            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
1414                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
1415            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
1416                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
1417            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
1418        }
1419    }
1420
1421    // simdgroup reduction
1422    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
1423    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
1424    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
1425    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);
1426
1427    if (simd_lane == 0) {
1428        device float* o0 = outputs + (tok_base + 0) * rows;
1429        if (row_base     < rows) o0[row_base]     = s00;
1430        if (row_base + 1 < rows) o0[row_base + 1] = s01;
1431        if (row_base + 2 < rows) o0[row_base + 2] = s02;
1432        if (row_base + 3 < rows) o0[row_base + 3] = s03;
1433
1434        if (tok_count > 1) {
1435            device float* o1 = outputs + (tok_base + 1) * rows;
1436            if (row_base     < rows) o1[row_base]     = s10;
1437            if (row_base + 1 < rows) o1[row_base + 1] = s11;
1438            if (row_base + 2 < rows) o1[row_base + 2] = s12;
1439            if (row_base + 3 < rows) o1[row_base + 3] = s13;
1440        }
1441        if (tok_count > 2) {
1442            device float* o2 = outputs + (tok_base + 2) * rows;
1443            if (row_base     < rows) o2[row_base]     = s20;
1444            if (row_base + 1 < rows) o2[row_base + 1] = s21;
1445            if (row_base + 2 < rows) o2[row_base + 2] = s22;
1446            if (row_base + 3 < rows) o2[row_base + 3] = s23;
1447        }
1448        if (tok_count > 3) {
1449            device float* o3 = outputs + (tok_base + 3) * rows;
1450            if (row_base     < rows) o3[row_base]     = s30;
1451            if (row_base + 1 < rows) o3[row_base + 1] = s31;
1452            if (row_base + 2 < rows) o3[row_base + 2] = s32;
1453            if (row_base + 3 < rows) o3[row_base + 3] = s33;
1454        }
1455    }
1456}
1457
1458// ── matmul_vec_q4_batch ────────────────────────────────────────────────
1459// Batched Q4_0 matrix-vector multiply for M input vectors.
1460// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
1461kernel void matmul_vec_q4_batch(
1462    device const uchar* matrix   [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
1463    device const float* inputs   [[buffer(1)]],  // [M, cols] input batch
1464    device float* outputs        [[buffer(2)]],  // [M, rows] output batch
1465    constant uint& num_tokens    [[buffer(3)]],  // M
1466    constant uint& rows          [[buffer(4)]],
1467    constant uint& cols          [[buffer(5)]],
1468    uint tgid [[threadgroup_position_in_grid]],
1469    uint tid [[thread_index_in_threadgroup]],
1470    uint simd_lane [[thread_index_in_simdgroup]],
1471    uint simd_id [[simdgroup_index_in_threadgroup]])
1472{
1473    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
1474    uint token = tgid / row_tgs;
1475    uint tg_in_token = tgid % row_tgs;
1476    if (token >= num_tokens) return;
1477
1478    threadgroup float vec_tile[VEC_TILE_SIZE];
1479    device const float* input = inputs + token * cols;
1480    for (uint i = tid; i < cols; i += 256) {
1481        vec_tile[i] = input[i];
1482    }
1483    threadgroup_barrier(mem_flags::mem_threadgroup);
1484
1485    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
1486    if (row_base >= rows) return;
1487
1488    uint blocks_per_row = cols / 32;
1489    uint row_bytes = blocks_per_row * 18;
1490
1491    device const uchar* r0 = matrix + row_base * row_bytes;
1492    device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1493    device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1494    device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1495
1496    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
1497
1498    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1499        uint bb = blk * 18;
1500        uint vb = blk * 32;
1501
1502        float sc0 = float(*(device const half*)(r0 + bb));
1503        float sc1 = float(*(device const half*)(r1 + bb));
1504        float sc2 = float(*(device const half*)(r2 + bb));
1505        float sc3 = float(*(device const half*)(r3 + bb));
1506
1507        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
1508        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
1509        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
1510        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
1511
1512        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
1513        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
1514        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
1515        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
1516        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
1517        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
1518        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
1519        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
1520
1521        float bd0=0, bd1=0, bd2=0, bd3=0;
1522        uchar4 b;
1523
1524        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
1525        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
1526        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
1527        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
1528
1529        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
1530        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
1531        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
1532        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
1533
1534        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
1535        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
1536        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
1537        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
1538
1539        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
1540        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
1541        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
1542        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
1543
1544        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
1545    }
1546
1547    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
1548    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
1549
1550    device float* output = outputs + token * rows;
1551    if (simd_lane == 0) {
1552        if (row_base     < rows) output[row_base]     = sum0;
1553        if (row_base + 1 < rows) output[row_base + 1] = sum1;
1554        if (row_base + 2 < rows) output[row_base + 2] = sum2;
1555        if (row_base + 3 < rows) output[row_base + 3] = sum3;
1556    }
1557}
1558
1559// ── copy_kv_batch ─────────────────────────────────────────────────────
1560// Copy K or V from a strided batch QKV buffer to the KV cache.
1561// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
1562// dst layout: contiguous [max_seq, kv_dim] cache.
1563kernel void copy_kv_batch(
1564    device const float* src  [[buffer(0)]],  // batch QKV buffer
1565    device float* dst        [[buffer(1)]],  // KV cache
1566    constant uint& M         [[buffer(2)]],  // num tokens in batch
1567    constant uint& kv_dim    [[buffer(3)]],  // floats per KV vector
1568    constant uint& base_pos  [[buffer(4)]],  // starting position in cache
1569    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
1570    constant uint& src_offset [[buffer(6)]], // float offset within each src row
1571    uint id [[thread_position_in_grid]])
1572{
1573    uint total = M * kv_dim;
1574    if (id >= total) return;
1575    uint token = id / kv_dim;
1576    uint d = id % kv_dim;
1577    uint dst_off = (base_pos + token) * kv_dim + d;
1578    uint src_off = token * src_stride + src_offset + d;
1579    dst[dst_off] = src[src_off];
1580}
1581
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
// The strides below (8 simdgroups, 256-thread loops) assume 256 threads per
// threadgroup. All seq_len scores are staged in threadgroup memory, so the
// score buffer must hold the longest sequence the KV cache can reach
// (MAX_SEQ_LEN, which the Rust model generator caps at 4096).
kernel void attention_batch(
    device const float* q_batch      [[buffer(0)]],  // batch QKV buf (strided)
    device const float* k_cache      [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim]
    device const float* v_cache      [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim]
    device float* output_batch       [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M                 [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos          [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads         [[buffer(6)]],
    constant uint& num_kv_heads      [[buffer(7)]],
    constant uint& head_dim          [[buffer(8)]],
    constant uint& q_stride          [[buffer(9)]],  // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores: one entry per attended position.
    // seq_len can reach MAX_SEQ_LEN (up to 4096), so a smaller buffer would be
    // written out of bounds during long prefills. 4096 floats = 16 KiB, which
    // fits the 32 KiB threadgroup-memory limit of Apple GPUs.
    threadgroup float scores[4096];

    // Step 1: Q * K^T with simdgroup reduction
    // 1/sqrt(head_dim) is loop-invariant; compute it once.
    float scale = fast::rsqrt(float(head_dim));
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
    // This is much better than the scalar version where only 64 of 256 threads are active.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
                 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
                 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
                 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
1706
1707// ── rope_qk_batch ─────────────────────────────────────────────────────
1708// Fused RoPE for both Q and K in a single dispatch, saving one kernel
1709// launch + memory barrier per layer. Both Q and K live in the same
1710// qkv_data buffer at different offsets within each token's row.
1711// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
1712kernel void rope_qk_batch(
1713    device float* qkv_data           [[buffer(0)]],  // [M, qkv_stride]
1714    constant uint& M                 [[buffer(1)]],   // num tokens
1715    constant uint& base_pos          [[buffer(2)]],   // starting position
1716    constant uint& num_q_heads       [[buffer(3)]],
1717    constant uint& num_kv_heads      [[buffer(4)]],
1718    constant uint& head_dim          [[buffer(5)]],
1719    constant uint& qkv_stride        [[buffer(6)]],   // floats per row
1720    constant float& theta            [[buffer(7)]],
1721    uint id [[thread_position_in_grid]])
1722{
1723    uint half_dim = head_dim / 2;
1724    uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
1725    uint token = id / total_pairs;
1726    uint pair = id % total_pairs;
1727    if (token >= M) return;
1728
1729    uint pos = base_pos + token;
1730    uint q_pairs = num_q_heads * half_dim;
1731
1732    uint h, i, offset;
1733    if (pair < q_pairs) {
1734        // Q head
1735        h = pair / half_dim;
1736        i = pair % half_dim;
1737        offset = token * qkv_stride + h * head_dim + i * 2;
1738    } else {
1739        // K head
1740        uint kp = pair - q_pairs;
1741        h = kp / half_dim;
1742        i = kp % half_dim;
1743        uint k_start = num_q_heads * head_dim;
1744        offset = token * qkv_stride + k_start + h * head_dim + i * 2;
1745    }
1746
1747    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
1748    float angle = float(pos) * freq;
1749    float cos_val = cos(angle);
1750    float sin_val = sin(angle);
1751
1752    float x0 = qkv_data[offset];
1753    float x1 = qkv_data[offset + 1];
1754    qkv_data[offset]     = x0 * cos_val - x1 * sin_val;
1755    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
1756}
1757
1758// ── copy_kv_both_batch ────────────────────────────────────────────────
1759// Fused K+V cache copy in a single dispatch: copies both K and V from
1760// the strided batch QKV buffer to their respective KV cache buffers.
1761// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
1762kernel void copy_kv_both_batch(
1763    device const float* src    [[buffer(0)]],  // batch QKV buffer [M, qkv_stride]
1764    device float* k_dst        [[buffer(1)]],  // K cache [max_seq, kv_dim]
1765    device float* v_dst        [[buffer(2)]],  // V cache [max_seq, kv_dim]
1766    constant uint& M           [[buffer(3)]],  // num tokens in batch
1767    constant uint& kv_dim      [[buffer(4)]],  // floats per KV vector
1768    constant uint& base_pos    [[buffer(5)]],  // starting position in cache
1769    constant uint& src_stride  [[buffer(6)]],  // floats per row in src (qkv_stride)
1770    constant uint& k_offset    [[buffer(7)]],  // float offset of K within each src row
1771    constant uint& v_offset    [[buffer(8)]],  // float offset of V within each src row
1772    uint id [[thread_position_in_grid]])
1773{
1774    // Total elements = M * kv_dim * 2 (K + V)
1775    uint total_kv = M * kv_dim;
1776    if (id >= total_kv * 2) return;
1777
1778    uint is_v = id / total_kv;        // 0 = K, 1 = V
1779    uint local_id = id % total_kv;
1780    uint token = local_id / kv_dim;
1781    uint d = local_id % kv_dim;
1782
1783    uint dst_off = (base_pos + token) * kv_dim + d;
1784    uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
1785
1786    if (is_v) {
1787        v_dst[dst_off] = src[src_off];
1788    } else {
1789        k_dst[dst_off] = src[src_off];
1790    }
1791}
1792"#
1793    .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
1794}
1795
1796// ---------------------------------------------------------------------------
1797// model.rs generation
1798// ---------------------------------------------------------------------------
1799
1800fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
1801    let mut code = String::with_capacity(48 * 1024);
1802    emit_model_header(&mut code, config)?;
1803    emit_metal_model_struct(&mut code, config)?;
1804    emit_layer_buffers_struct(&mut code)?;
1805    emit_metal_model_impl(&mut code, config)?;
1806    emit_helper_functions(&mut code)?;
1807    Ok(code)
1808}
1809
1810fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
1811    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
1812    writeln!(
1813        code,
1814        "//! Model: {} ({} layers, hidden={})",
1815        config.architecture, config.num_layers, config.hidden_size
1816    )?;
1817    writeln!(code, "//!")?;
1818    writeln!(
1819        code,
1820        "//! Uses native Metal compute pipelines via the metal crate."
1821    )?;
1822    writeln!(code)?;
1823    writeln!(code, "#![allow(dead_code)]")?;
1824    writeln!(code)?;
1825    writeln!(code, "use metal::*;")?;
1826    writeln!(code, "#[allow(unused_imports)]")?;
1827    writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
1828    writeln!(code, "use std::mem;")?;
1829    writeln!(code)?;
1830
1831    // Model constants
1832    writeln!(
1833        code,
1834        "// ── Model constants ──────────────────────────────────"
1835    )?;
1836    writeln!(
1837        code,
1838        "pub const HIDDEN_SIZE: usize = {};",
1839        config.hidden_size
1840    )?;
1841    writeln!(
1842        code,
1843        "pub const INTERMEDIATE_SIZE: usize = {};",
1844        config.intermediate_size
1845    )?;
1846    writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
1847    writeln!(
1848        code,
1849        "pub const NUM_HEADS: usize = {};",
1850        config.num_attention_heads
1851    )?;
1852    writeln!(
1853        code,
1854        "pub const NUM_KV_HEADS: usize = {};",
1855        config.num_kv_heads
1856    )?;
1857    writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
1858    writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
1859    let effective_seq_len = config.max_seq_len.min(4096);
1860    writeln!(
1861        code,
1862        "pub const MAX_SEQ_LEN: usize = {};  // capped from model's {}",
1863        effective_seq_len, config.max_seq_len
1864    )?;
1865    writeln!(
1866        code,
1867        "pub const RMS_NORM_EPS: f32 = {:e};",
1868        config.rms_norm_eps
1869    )?;
1870    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
1871    writeln!(
1872        code,
1873        "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
1874    )?;
1875    writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
1876    writeln!(code)?;
1877
1878    Ok(())
1879}
1880
1881fn emit_metal_model_struct(
1882    code: &mut String,
1883    _config: &ModelConfig,
1884) -> Result<(), MetalCodegenError> {
1885    writeln!(
1886        code,
1887        "// ── MetalModel ──────────────────────────────────────────"
1888    )?;
1889    writeln!(code)?;
1890    writeln!(
1891        code,
1892        "/// Metal-accelerated transformer model for Apple Silicon."
1893    )?;
1894    writeln!(code, "///")?;
1895    writeln!(
1896        code,
1897        "/// Uses unified memory for zero-copy weight access and native Metal"
1898    )?;
1899    writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
1900    writeln!(code, "pub struct MetalModel {{")?;
1901    writeln!(code, "    device: Device,")?;
1902    writeln!(code, "    queue: CommandQueue,")?;
1903    writeln!(code)?;
1904    writeln!(code, "    // ── Compute pipelines ──")?;
1905    writeln!(code, "    matmul_pipeline: ComputePipelineState,")?;
1906    writeln!(code, "    matmul_q8_pipeline: ComputePipelineState,")?;
1907    writeln!(code, "    matmul_q4_pipeline: ComputePipelineState,")?;
1908    writeln!(code, "    rms_norm_pipeline: ComputePipelineState,")?;
1909    writeln!(code, "    rope_pipeline: ComputePipelineState,")?;
1910    writeln!(code, "    softmax_pipeline: ComputePipelineState,")?;
1911    writeln!(code, "    silu_mul_pipeline: ComputePipelineState,")?;
1912    writeln!(code, "    silu_mul_fused_pipeline: ComputePipelineState,")?;
1913    writeln!(code, "    add_pipeline: ComputePipelineState,")?;
1914    writeln!(code, "    attention_pipeline: ComputePipelineState,")?;
1915    writeln!(code, "    add_inplace_pipeline: ComputePipelineState,")?;
1916    writeln!(code, "    copy_pipeline: ComputePipelineState,")?;
1917    writeln!(code, "    copy_offset_pipeline: ComputePipelineState,")?;
1918    writeln!(code)?;
1919    writeln!(code, "    // ── Batched prefill pipelines ──")?;
1920    writeln!(code, "    matmul_batch_pipeline: ComputePipelineState,")?;
1921    writeln!(code, "    matmul_q8_batch_pipeline: ComputePipelineState,")?;
1922    writeln!(
1923        code,
1924        "    matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
1925    )?;
1926    writeln!(code, "    matmul_q4_batch_pipeline: ComputePipelineState,")?;
1927    writeln!(code, "    rms_norm_batch_pipeline: ComputePipelineState,")?;
1928    writeln!(code, "    rope_batch_pipeline: ComputePipelineState,")?;
1929    writeln!(
1930        code,
1931        "    silu_mul_fused_batch_pipeline: ComputePipelineState,"
1932    )?;
1933    writeln!(
1934        code,
1935        "    add_inplace_batch_pipeline: ComputePipelineState,"
1936    )?;
1937    writeln!(
1938        code,
1939        "    copy_embedding_batch_pipeline: ComputePipelineState,"
1940    )?;
1941    writeln!(code, "    attention_batch_pipeline: ComputePipelineState,")?;
1942    writeln!(code, "    copy_kv_batch_pipeline: ComputePipelineState,")?;
1943    writeln!(code, "    rope_qk_batch_pipeline: ComputePipelineState,")?;
1944    writeln!(
1945        code,
1946        "    copy_kv_both_batch_pipeline: ComputePipelineState,"
1947    )?;
1948    writeln!(code)?;
1949    writeln!(code, "    // ── Weight buffers (Metal shared memory) ──")?;
1950    writeln!(
1951        code,
1952        "    /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
1953    )?;
1954    writeln!(code, "    embed_buf: Buffer,")?;
1955    writeln!(code)?;
1956    writeln!(code, "    /// Per-layer weight buffers")?;
1957    writeln!(code, "    layers: Vec<LayerBuffers>,")?;
1958    writeln!(code)?;
1959    writeln!(code, "    /// Final layer-norm weight [HIDDEN_SIZE]")?;
1960    writeln!(code, "    norm_buf: Buffer,")?;
1961    writeln!(code)?;
1962    writeln!(
1963        code,
1964        "    /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
1965    )?;
1966    writeln!(code, "    lm_head_buf: Buffer,")?;
1967    writeln!(code)?;
1968    writeln!(
1969        code,
1970        "    // ── Working buffers (pre-allocated, reused every forward pass) ──"
1971    )?;
1972    writeln!(code, "    hidden_buf: Buffer,")?;
1973    writeln!(code, "    residual_buf: Buffer,")?;
1974    writeln!(code, "    normed_buf: Buffer,")?;
1975    writeln!(
1976        code,
1977        "    /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
1978    )?;
1979    writeln!(code, "    qkv_buf: Buffer,")?;
1980    writeln!(code, "    attn_out_buf: Buffer,")?;
1981    writeln!(code, "    attn_proj_buf: Buffer,")?;
1982    writeln!(
1983        code,
1984        "    /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
1985    )?;
1986    writeln!(code, "    gate_up_buf: Buffer,")?;
1987    writeln!(code, "    ffn_hidden_buf: Buffer,")?;
1988    writeln!(code, "    ffn_out_buf: Buffer,")?;
1989    writeln!(code, "    add_tmp_buf: Buffer,")?;
1990    writeln!(code, "    logits_buf: Buffer,")?;
1991    writeln!(code)?;
1992    writeln!(code, "    // ── Batched prefill working buffers ──")?;
1993    writeln!(code, "    /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
1994    writeln!(code, "    batch_hidden_buf: Buffer,")?;
1995    writeln!(
1996        code,
1997        "    /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
1998    )?;
1999    writeln!(code, "    batch_residual_buf: Buffer,")?;
2000    writeln!(code, "    /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2001    writeln!(code, "    batch_qkv_buf: Buffer,")?;
2002    writeln!(
2003        code,
2004        "    /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2005    )?;
2006    writeln!(code, "    batch_attn_out_buf: Buffer,")?;
2007    writeln!(
2008        code,
2009        "    /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2010    )?;
2011    writeln!(code, "    batch_attn_proj_buf: Buffer,")?;
2012    writeln!(
2013        code,
2014        "    /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2015    )?;
2016    writeln!(code, "    batch_gate_up_buf: Buffer,")?;
2017    writeln!(
2018        code,
2019        "    /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2020    )?;
2021    writeln!(code, "    batch_ffn_hidden_buf: Buffer,")?;
2022    writeln!(code, "    /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2023    writeln!(code, "    batch_ffn_out_buf: Buffer,")?;
2024    writeln!(code, "    /// Token IDs buffer for batch embedding lookup")?;
2025    writeln!(code, "    batch_tokens_buf: Buffer,")?;
2026    writeln!(code, "    /// Positions buffer for batch RoPE")?;
2027    writeln!(code, "    batch_positions_buf: Buffer,")?;
2028    writeln!(code)?;
2029    writeln!(code, "    // ── KV cache buffers (per-layer) ──")?;
2030    writeln!(code, "    k_cache: Vec<Buffer>,  // per-layer")?;
2031    writeln!(code, "    v_cache: Vec<Buffer>,  // per-layer")?;
2032    writeln!(code)?;
2033    writeln!(code, "    // ── Inference state ──")?;
2034    writeln!(code, "    pos: usize,")?;
2035    writeln!(code)?;
2036    writeln!(
2037        code,
2038        "    /// Previous command buffer for double-buffered prefill."
2039    )?;
2040    writeln!(
2041        code,
2042        "    /// While the GPU executes token N, the CPU can encode token N+1."
2043    )?;
2044    writeln!(code, "    prev_cmd: Option<CommandBuffer>,")?;
2045    writeln!(code, "}}")?;
2046    writeln!(code)?;
2047
2048    Ok(())
2049}
2050
2051fn emit_layer_buffers_struct(code: &mut String) -> Result<(), MetalCodegenError> {
2052    writeln!(
2053        code,
2054        "/// Per-layer weight buffers for attention and FFN projections."
2055    )?;
2056    writeln!(code, "struct LayerBuffers {{")?;
2057    writeln!(code, "    attn_norm: Buffer,")?;
2058    writeln!(
2059        code,
2060        "    /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2061    )?;
2062    writeln!(code, "    qkv_weight: Buffer,")?;
2063    writeln!(code, "    o_weight: Buffer,")?;
2064    writeln!(code, "    ffn_norm: Buffer,")?;
2065    writeln!(
2066        code,
2067        "    /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2068    )?;
2069    writeln!(code, "    gate_up_weight: Buffer,")?;
2070    writeln!(code, "    down_weight: Buffer,")?;
2071    writeln!(code, "}}")?;
2072    writeln!(code)?;
2073
2074    Ok(())
2075}
2076
2077fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2078    let hidden = config.hidden_size;
2079    let intermediate = config.intermediate_size;
2080    let _num_layers = config.num_layers;
2081    let num_heads = config.num_attention_heads;
2082    let num_kv_heads = config.num_kv_heads;
2083    let head_dim = config.head_dim;
2084    let vocab = config.vocab_size;
2085    let effective_seq_len = config.max_seq_len.min(4096);
2086    let is_q8 = config.dtype == DType::Q8_0;
2087    let is_q4 = config.dtype == DType::Q4_0;
2088    let kv_dim = num_kv_heads * head_dim;
2089
2090    writeln!(code, "impl MetalModel {{")?;
2091
2092    // ── new() ──
2093    writeln!(
2094        code,
2095        "    /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
2096    )?;
2097    writeln!(code, "    ///")?;
2098    writeln!(
2099        code,
2100        "    /// `weights` is the raw weight blob produced by `forge export-weights`."
2101    )?;
2102    writeln!(code, "    pub fn new(weights: &[u8]) -> Self {{")?;
2103    writeln!(
2104        code,
2105        "        let device = Device::system_default().expect(\"no Metal device found\");"
2106    )?;
2107    writeln!(code, "        let queue = device.new_command_queue();")?;
2108    writeln!(code)?;
2109
2110    // Compile shaders
2111    writeln!(
2112        code,
2113        "        // Compile Metal shaders from embedded source"
2114    )?;
2115    writeln!(
2116        code,
2117        "        let shader_source = include_str!(\"../shaders/kernels.metal\");"
2118    )?;
2119    writeln!(
2120        code,
2121        "        let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
2122    )?;
2123    writeln!(
2124        code,
2125        "            .expect(\"failed to compile Metal shaders\");"
2126    )?;
2127    writeln!(code)?;
2128
2129    // Create compute pipelines
2130    writeln!(code, "        // Create compute pipelines")?;
2131    for (var, fn_name) in [
2132        ("matmul_pipeline", "matmul_vec"),
2133        ("matmul_q8_pipeline", "matmul_vec_q8"),
2134        ("matmul_q4_pipeline", "matmul_vec_q4"),
2135        ("rms_norm_pipeline", "rms_norm"),
2136        ("rope_pipeline", "rope"),
2137        ("softmax_pipeline", "softmax"),
2138        ("silu_mul_pipeline", "silu_mul"),
2139        ("silu_mul_fused_pipeline", "silu_mul_fused"),
2140        ("add_pipeline", "elementwise_add"),
2141        ("attention_pipeline", "attention"),
2142        ("add_inplace_pipeline", "add_inplace"),
2143        ("copy_pipeline", "copy_buffer"),
2144        ("copy_offset_pipeline", "copy_offset"),
2145        ("matmul_batch_pipeline", "matmul_vec_batch"),
2146        ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
2147        ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
2148        ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
2149        ("rms_norm_batch_pipeline", "rms_norm_batch"),
2150        ("rope_batch_pipeline", "rope_batch"),
2151        ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
2152        ("add_inplace_batch_pipeline", "add_inplace_batch"),
2153        ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
2154        ("attention_batch_pipeline", "attention_batch"),
2155        ("copy_kv_batch_pipeline", "copy_kv_batch"),
2156        ("rope_qk_batch_pipeline", "rope_qk_batch"),
2157        ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
2158    ] {
2159        writeln!(
2160            code,
2161            "        let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
2162        )?;
2163    }
2164    writeln!(code)?;
2165
2166    // Weight loading
2167    writeln!(
2168        code,
2169        "        // Load weights into Metal shared-memory buffers"
2170    )?;
2171    writeln!(code, "        let f32_size = mem::size_of::<f32>();")?;
2172    writeln!(code, "        let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
2173    writeln!(code, "        let hidden_elems = HIDDEN_SIZE;")?;
2174    writeln!(code)?;
2175    writeln!(
2176        code,
2177        "        let cursor = std::cell::Cell::new(0usize);  // byte cursor into `weights`"
2178    )?;
2179    writeln!(code)?;
2180    writeln!(
2181        code,
2182        "        // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
2183    )?;
2184    writeln!(
2185        code,
2186        "        let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
2187    )?;
2188    writeln!(code, "            let byte_len = n * f32_size;")?;
2189    writeln!(code, "            let cur = cursor.get();")?;
2190    writeln!(
2191        code,
2192        "            let data = &weights[cur..cur + byte_len];"
2193    )?;
2194    writeln!(code, "            cursor.set(cur + byte_len);")?;
2195    writeln!(code, "            device.new_buffer_with_data(")?;
2196    writeln!(code, "                data.as_ptr() as *const _,")?;
2197    writeln!(code, "                byte_len as u64,")?;
2198    writeln!(
2199        code,
2200        "                MTLResourceOptions::StorageModeShared,"
2201    )?;
2202    writeln!(code, "            )")?;
2203    writeln!(code, "        }};")?;
2204    writeln!(code)?;
2205
2206    if is_q8 {
2207        // For Q8_0 models, projection weights are stored as raw Q8_0 bytes.
2208        // We load them directly into Metal buffers without dequantizing,
2209        // and use the matmul_vec_q8 shader that operates on quantized data.
2210        // This halves GPU memory usage and memory bandwidth vs f32 dequantization.
2211        writeln!(
2212            code,
2213            "        // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
2214        )?;
2215        writeln!(
2216            code,
2217            "        // as raw bytes into a Metal buffer (no dequantization)."
2218        )?;
2219        writeln!(
2220            code,
2221            "        // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
2222        )?;
2223        writeln!(
2224            code,
2225            "        let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
2226        )?;
2227        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2228        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
2229        writeln!(code, "            let total_raw = rows * row_bytes;")?;
2230        writeln!(code, "            let cur = cursor.get();")?;
2231        writeln!(
2232            code,
2233            "            let data = &weights[cur..cur + total_raw];"
2234        )?;
2235        writeln!(code, "            cursor.set(cur + total_raw);")?;
2236        writeln!(code, "            device.new_buffer_with_data(")?;
2237        writeln!(code, "                data.as_ptr() as *const _,")?;
2238        writeln!(code, "                total_raw as u64,")?;
2239        writeln!(
2240            code,
2241            "                MTLResourceOptions::StorageModeShared,"
2242        )?;
2243        writeln!(code, "            )")?;
2244        writeln!(code, "        }};")?;
2245        writeln!(code)?;
2246        writeln!(
2247            code,
2248            "        // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
2249        )?;
2250        writeln!(
2251            code,
2252            "        // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
2253        )?;
2254        writeln!(
2255            code,
2256            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
2257        )?;
2258        writeln!(
2259            code,
2260            "        let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
2261        )?;
2262        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2263        writeln!(code, "            let row_bytes = blocks_per_row * 34;")?;
2264        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
2265        writeln!(code, "            let cur = cursor.get();")?;
2266        writeln!(
2267            code,
2268            "            let data = &weights[cur..cur + total_raw];"
2269        )?;
2270        writeln!(code, "            cursor.set(cur + total_raw);")?;
2271        writeln!(code, "            device.new_buffer_with_data(")?;
2272        writeln!(code, "                data.as_ptr() as *const _,")?;
2273        writeln!(code, "                total_raw as u64,")?;
2274        writeln!(
2275            code,
2276            "                MTLResourceOptions::StorageModeShared,"
2277        )?;
2278        writeln!(code, "            )")?;
2279        writeln!(code, "        }};")?;
2280        writeln!(code)?;
2281    }
2282
2283    if is_q4 {
2284        // For Q4_0 models, projection weights are stored as raw Q4_0 bytes.
2285        // We load them directly into Metal buffers without dequantizing,
2286        // and use the matmul_vec_q4 shader that operates on quantized data.
2287        // This quarters GPU memory usage vs f32 dequantization.
2288        writeln!(
2289            code,
2290            "        // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
2291        )?;
2292        writeln!(
2293            code,
2294            "        // as raw bytes into a Metal buffer (no dequantization)."
2295        )?;
2296        writeln!(
2297            code,
2298            "        // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
2299        )?;
2300        writeln!(
2301            code,
2302            "        let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
2303        )?;
2304        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2305        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
2306        writeln!(code, "            let total_raw = rows * row_bytes;")?;
2307        writeln!(code, "            let cur = cursor.get();")?;
2308        writeln!(
2309            code,
2310            "            let data = &weights[cur..cur + total_raw];"
2311        )?;
2312        writeln!(code, "            cursor.set(cur + total_raw);")?;
2313        writeln!(code, "            device.new_buffer_with_data(")?;
2314        writeln!(code, "                data.as_ptr() as *const _,")?;
2315        writeln!(code, "                total_raw as u64,")?;
2316        writeln!(
2317            code,
2318            "                MTLResourceOptions::StorageModeShared,"
2319        )?;
2320        writeln!(code, "            )")?;
2321        writeln!(code, "        }};")?;
2322        writeln!(code)?;
2323        writeln!(
2324            code,
2325            "        // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
2326        )?;
2327        writeln!(
2328            code,
2329            "        // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
2330        )?;
2331        writeln!(
2332            code,
2333            "        // Used for fused QKV and gate+up projections where weights are adjacent in the file."
2334        )?;
2335        writeln!(
2336            code,
2337            "        let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
2338        )?;
2339        writeln!(code, "            let blocks_per_row = cols.div_ceil(32);")?;
2340        writeln!(code, "            let row_bytes = blocks_per_row * 18;")?;
2341        writeln!(code, "            let total_raw = total_rows * row_bytes;")?;
2342        writeln!(code, "            let cur = cursor.get();")?;
2343        writeln!(
2344            code,
2345            "            let data = &weights[cur..cur + total_raw];"
2346        )?;
2347        writeln!(code, "            cursor.set(cur + total_raw);")?;
2348        writeln!(code, "            device.new_buffer_with_data(")?;
2349        writeln!(code, "                data.as_ptr() as *const _,")?;
2350        writeln!(code, "                total_raw as u64,")?;
2351        writeln!(
2352            code,
2353            "                MTLResourceOptions::StorageModeShared,"
2354        )?;
2355        writeln!(code, "            )")?;
2356        writeln!(code, "        }};")?;
2357        writeln!(code)?;
2358    }
2359
2360    writeln!(
2361        code,
2362        "        let embed_buf = next_f32_buffer(&device, embed_elems);"
2363    )?;
2364    writeln!(code)?;
2365
2366    // Per-layer weights
2367    writeln!(
2368        code,
2369        "        let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
2370    )?;
2371    writeln!(code, "        for _layer in 0..NUM_LAYERS {{")?;
2372
2373    // attn_norm is always f32
2374    writeln!(
2375        code,
2376        "            let attn_norm = next_f32_buffer(&device, hidden_elems);"
2377    )?;
2378
2379    let qkv_rows = hidden + 2 * kv_dim;
2380    if is_q8 {
2381        // Fused Q+K+V weight: read all three consecutive Q8_0 matrices as one buffer
2382        writeln!(
2383            code,
2384            "            let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
2385        )?;
2386        writeln!(
2387            code,
2388            "            let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
2389        )?;
2390    } else if is_q4 {
2391        // Fused Q+K+V weight: read all three consecutive Q4_0 matrices as one buffer
2392        writeln!(
2393            code,
2394            "            let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
2395        )?;
2396        writeln!(
2397            code,
2398            "            let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
2399        )?;
2400    } else {
2401        // Fused Q+K+V weight: read all three as a single contiguous f32 buffer
2402        writeln!(
2403            code,
2404            "            let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
2405        )?;
2406        writeln!(
2407            code,
2408            "            let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
2409        )?;
2410    }
2411
2412    // ffn_norm is always f32
2413    writeln!(
2414        code,
2415        "            let ffn_norm = next_f32_buffer(&device, hidden_elems);"
2416    )?;
2417
2418    let gate_up_rows = 2 * intermediate;
2419    if is_q8 {
2420        // Fused gate+up weight: read both consecutive Q8_0 matrices as one buffer
2421        writeln!(
2422            code,
2423            "            let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
2424        )?;
2425        writeln!(
2426            code,
2427            "            let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
2428        )?;
2429    } else if is_q4 {
2430        // Fused gate+up weight: read both consecutive Q4_0 matrices as one buffer
2431        writeln!(
2432            code,
2433            "            let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
2434        )?;
2435        writeln!(
2436            code,
2437            "            let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
2438        )?;
2439    } else {
2440        // Fused gate+up weight: read both as a single contiguous f32 buffer
2441        writeln!(
2442            code,
2443            "            let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
2444        )?;
2445        writeln!(
2446            code,
2447            "            let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
2448        )?;
2449    }
2450
2451    writeln!(code, "            layers.push(LayerBuffers {{")?;
2452    writeln!(code, "                attn_norm,")?;
2453    writeln!(code, "                qkv_weight,")?;
2454    writeln!(code, "                o_weight,")?;
2455    writeln!(code, "                ffn_norm,")?;
2456    writeln!(code, "                gate_up_weight,")?;
2457    writeln!(code, "                down_weight,")?;
2458    writeln!(code, "            }});")?;
2459    writeln!(code, "        }}")?;
2460    writeln!(code)?;
2461
2462    // final_norm is always f32
2463    writeln!(
2464        code,
2465        "        let norm_buf = next_f32_buffer(&device, hidden_elems);"
2466    )?;
2467    writeln!(code)?;
2468
2469    // lm_head
2470    if is_q8 {
2471        writeln!(
2472            code,
2473            "        let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
2474        )?;
2475    } else if is_q4 {
2476        writeln!(
2477            code,
2478            "        let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
2479        )?;
2480    } else {
2481        writeln!(
2482            code,
2483            "        let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
2484        )?;
2485    }
2486    writeln!(code)?;
2487
2488    // Working buffers
2489    let hidden_bytes = hidden * 4;
2490    let _kv_dim_bytes = kv_dim * 4;
2491    let intermediate_bytes = intermediate * 4;
2492    let vocab_bytes = vocab * 4;
2493    let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
2494
2495    writeln!(
2496        code,
2497        "        // Allocate working buffers (shared memory for zero-copy)"
2498    )?;
2499    writeln!(
2500        code,
2501        "        let opts = MTLResourceOptions::StorageModeShared;"
2502    )?;
2503    writeln!(
2504        code,
2505        "        let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2506    )?;
2507    writeln!(
2508        code,
2509        "        let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2510    )?;
2511    let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
2512    writeln!(
2513        code,
2514        "        let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2515    )?;
2516    writeln!(
2517        code,
2518        "        // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
2519    )?;
2520    writeln!(
2521        code,
2522        "        let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
2523    )?;
2524    writeln!(
2525        code,
2526        "        let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2527    )?;
2528    writeln!(
2529        code,
2530        "        let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2531    )?;
2532    let gate_up_buf_bytes = 2 * intermediate * 4;
2533    writeln!(
2534        code,
2535        "        // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
2536    )?;
2537    writeln!(
2538        code,
2539        "        let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
2540    )?;
2541    writeln!(
2542        code,
2543        "        let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
2544    )?;
2545    writeln!(
2546        code,
2547        "        let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2548    )?;
2549    writeln!(
2550        code,
2551        "        let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2552    )?;
2553    writeln!(
2554        code,
2555        "        let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
2556    )?;
2557    writeln!(code)?;
2558
2559    // Batch prefill working buffers
2560    let batch_hidden_bytes = hidden * 4; // per-token
2561    let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
2562    let batch_gate_up_bytes = 2 * intermediate * 4;
2563    let batch_intermediate_bytes = intermediate * 4;
2564    writeln!(
2565        code,
2566        "        // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
2567    )?;
2568    writeln!(
2569        code,
2570        "        let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2571    )?;
2572    writeln!(
2573        code,
2574        "        let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2575    )?;
2576    writeln!(
2577        code,
2578        "        let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
2579    )?;
2580    writeln!(
2581        code,
2582        "        let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2583    )?;
2584    writeln!(
2585        code,
2586        "        let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2587    )?;
2588    writeln!(
2589        code,
2590        "        let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
2591    )?;
2592    writeln!(
2593        code,
2594        "        let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
2595    )?;
2596    writeln!(
2597        code,
2598        "        let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2599    )?;
2600    writeln!(
2601        code,
2602        "        let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
2603    )?;
2604    writeln!(
2605        code,
2606        "        let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
2607    )?;
2608    writeln!(code)?;
2609
2610    // KV cache buffers
2611    writeln!(code, "        // KV cache buffers (per-layer)")?;
2612    writeln!(
2613        code,
2614        "        let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
2615    )?;
2616    writeln!(
2617        code,
2618        "        let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
2619    )?;
2620    writeln!(code, "        for _ in 0..NUM_LAYERS {{")?;
2621    writeln!(
2622        code,
2623        "            k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
2624    )?;
2625    writeln!(
2626        code,
2627        "            v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
2628    )?;
2629    writeln!(code, "        }}")?;
2630    writeln!(code)?;
2631
2632    writeln!(code, "        Self {{")?;
2633    writeln!(code, "            device,")?;
2634    writeln!(code, "            queue,")?;
2635    writeln!(code, "            matmul_pipeline,")?;
2636    writeln!(code, "            matmul_q8_pipeline,")?;
2637    writeln!(code, "            matmul_q4_pipeline,")?;
2638    writeln!(code, "            rms_norm_pipeline,")?;
2639    writeln!(code, "            rope_pipeline,")?;
2640    writeln!(code, "            softmax_pipeline,")?;
2641    writeln!(code, "            silu_mul_pipeline,")?;
2642    writeln!(code, "            silu_mul_fused_pipeline,")?;
2643    writeln!(code, "            add_pipeline,")?;
2644    writeln!(code, "            attention_pipeline,")?;
2645    writeln!(code, "            add_inplace_pipeline,")?;
2646    writeln!(code, "            copy_pipeline,")?;
2647    writeln!(code, "            copy_offset_pipeline,")?;
2648    writeln!(code, "            matmul_batch_pipeline,")?;
2649    writeln!(code, "            matmul_q8_batch_pipeline,")?;
2650    writeln!(code, "            matmul_q8_gemm_batch_pipeline,")?;
2651    writeln!(code, "            matmul_q4_batch_pipeline,")?;
2652    writeln!(code, "            rms_norm_batch_pipeline,")?;
2653    writeln!(code, "            rope_batch_pipeline,")?;
2654    writeln!(code, "            silu_mul_fused_batch_pipeline,")?;
2655    writeln!(code, "            add_inplace_batch_pipeline,")?;
2656    writeln!(code, "            copy_embedding_batch_pipeline,")?;
2657    writeln!(code, "            attention_batch_pipeline,")?;
2658    writeln!(code, "            copy_kv_batch_pipeline,")?;
2659    writeln!(code, "            rope_qk_batch_pipeline,")?;
2660    writeln!(code, "            copy_kv_both_batch_pipeline,")?;
2661    writeln!(code, "            embed_buf,")?;
2662    writeln!(code, "            layers,")?;
2663    writeln!(code, "            norm_buf,")?;
2664    writeln!(code, "            lm_head_buf,")?;
2665    writeln!(code, "            hidden_buf,")?;
2666    writeln!(code, "            residual_buf,")?;
2667    writeln!(code, "            normed_buf,")?;
2668    writeln!(code, "            qkv_buf,")?;
2669    writeln!(code, "            attn_out_buf,")?;
2670    writeln!(code, "            attn_proj_buf,")?;
2671    writeln!(code, "            gate_up_buf,")?;
2672    writeln!(code, "            ffn_hidden_buf,")?;
2673    writeln!(code, "            ffn_out_buf,")?;
2674    writeln!(code, "            add_tmp_buf,")?;
2675    writeln!(code, "            logits_buf,")?;
2676    writeln!(code, "            batch_hidden_buf,")?;
2677    writeln!(code, "            batch_residual_buf,")?;
2678    writeln!(code, "            batch_qkv_buf,")?;
2679    writeln!(code, "            batch_attn_out_buf,")?;
2680    writeln!(code, "            batch_attn_proj_buf,")?;
2681    writeln!(code, "            batch_gate_up_buf,")?;
2682    writeln!(code, "            batch_ffn_hidden_buf,")?;
2683    writeln!(code, "            batch_ffn_out_buf,")?;
2684    writeln!(code, "            batch_tokens_buf,")?;
2685    writeln!(code, "            batch_positions_buf,")?;
2686    writeln!(code, "            k_cache,")?;
2687    writeln!(code, "            v_cache,")?;
2688    writeln!(code, "            pos: 0,")?;
2689    writeln!(code, "            prev_cmd: None,")?;
2690    writeln!(code, "        }}")?;
2691    writeln!(code, "    }}")?;
2692    writeln!(code)?;
2693
2694    // ── forward() ──
2695    writeln!(
2696        code,
2697        "    /// Run the forward pass for a single token at the current position."
2698    )?;
2699    writeln!(code, "    ///")?;
2700    writeln!(
2701        code,
2702        "    /// Returns logits over the vocabulary as a `Vec<f32>`."
2703    )?;
2704    writeln!(code, "    ///")?;
2705    writeln!(
2706        code,
2707        "    /// All GPU operations are encoded into a single command buffer and"
2708    )?;
2709    writeln!(
2710        code,
2711        "    /// committed once at the end, avoiding per-operation synchronization."
2712    )?;
2713    writeln!(
2714        code,
2715        "    pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
2716    )?;
2717    writeln!(
2718        code,
2719        "        // Wait for any pending prefill command buffer"
2720    )?;
2721    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
2722    writeln!(code, "            prev.wait_until_completed();")?;
2723    writeln!(code, "        }}")?;
2724    writeln!(code)?;
2725    writeln!(code, "        let pos = self.pos;")?;
2726    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
2727    writeln!(code)?;
2728
2729    // Single compute encoder for the entire forward pass — no blit encoder
2730    // transitions. Copy operations use compute copy kernels instead of blits.
2731    let matmul_fn = if is_q8 {
2732        "dispatch_matmul_q8"
2733    } else if is_q4 {
2734        "dispatch_matmul_q4"
2735    } else {
2736        "dispatch_matmul"
2737    };
2738
2739    writeln!(
2740        code,
2741        "        // Single compute encoder for the entire forward pass (no blit transitions)"
2742    )?;
2743    writeln!(code, "        {{")?;
2744    writeln!(
2745        code,
2746        "            let enc = cmd.new_compute_command_encoder();"
2747    )?;
2748    writeln!(code)?;
2749
2750    // 1. Embedding lookup via CPU memcpy (unified memory — zero GPU dispatch overhead)
2751    writeln!(
2752        code,
2753        "            // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
2754    )?;
2755    writeln!(
2756        code,
2757        "            // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
2758    )?;
2759    writeln!(
2760        code,
2761        "            // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
2762    )?;
2763    writeln!(
2764        code,
2765        "            // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
2766        hidden * 4,
2767    )?;
2768    writeln!(code, "            unsafe {{")?;
2769    writeln!(
2770        code,
2771        "                let embed_ptr = self.embed_buf.contents() as *const f32;"
2772    )?;
2773    writeln!(
2774        code,
2775        "                let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
2776    )?;
2777    writeln!(
2778        code,
2779        "                let residual_ptr = self.residual_buf.contents() as *mut f32;"
2780    )?;
2781    writeln!(code, "                std::ptr::copy_nonoverlapping(")?;
2782    writeln!(
2783        code,
2784        "                    embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
2785    )?;
2786    writeln!(code, "                    hidden_ptr,")?;
2787    writeln!(code, "                    HIDDEN_SIZE,")?;
2788    writeln!(code, "                );")?;
2789    writeln!(
2790        code,
2791        "                std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
2792    )?;
2793    writeln!(code, "            }}")?;
2794    writeln!(code)?;
2795
2796    // 2. Transformer layers
2797    writeln!(code, "            // 2. Transformer layers")?;
2798    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
2799    writeln!(code)?;
2800    let q_byte_offset = 0usize;
2801    let k_byte_offset = hidden * 4;
2802    let v_byte_offset = (hidden + kv_dim) * 4;
2803
2804    writeln!(
2805        code,
2806        "                // Pre-attention: rms_norm, fused QKV projection, RoPE"
2807    )?;
2808    writeln!(
2809        code,
2810        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
2811    )?;
2812    writeln!(
2813        code,
2814        "                // Fused Q+K+V matmul: single dispatch for all three projections"
2815    )?;
2816    writeln!(
2817        code,
2818        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
2819    )?;
2820    writeln!(
2821        code,
2822        "                // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
2823    )?;
2824    writeln!(
2825        code,
2826        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
2827    )?;
2828    writeln!(
2829        code,
2830        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
2831    )?;
2832    writeln!(code)?;
2833    writeln!(
2834        code,
2835        "                // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
2836    )?;
2837    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
2838    writeln!(
2839        code,
2840        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
2841    )?;
2842    writeln!(
2843        code,
2844        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
2845    )?;
2846    writeln!(code)?;
2847    writeln!(
2848        code,
2849        "                // Attention using Q from qkv_buf (offset 0)"
2850    )?;
2851    writeln!(
2852        code,
2853        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
2854    )?;
2855    writeln!(
2856        code,
2857        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
2858    )?;
2859    writeln!(
2860        code,
2861        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
2862    )?;
2863    writeln!(
2864        code,
2865        "                // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
2866    )?;
2867    writeln!(
2868        code,
2869        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
2870    )?;
2871    writeln!(
2872        code,
2873        "                // Fused gate+up matmul: single dispatch for both projections"
2874    )?;
2875    writeln!(
2876        code,
2877        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
2878    )?;
2879    writeln!(
2880        code,
2881        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
2882    )?;
2883    writeln!(
2884        code,
2885        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
2886    )?;
2887    writeln!(
2888        code,
2889        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
2890    )?;
2891    writeln!(code, "            }}")?;
2892    writeln!(code)?;
2893
2894    // 3. Final RMS norm + logits
2895    writeln!(code, "            // 3. Final RMS norm + logits projection")?;
2896    writeln!(
2897        code,
2898        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
2899    )?;
2900    writeln!(
2901        code,
2902        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
2903    )?;
2904    writeln!(code)?;
2905    writeln!(code, "            enc.end_encoding();")?;
2906    writeln!(code, "        }}")?;
2907    writeln!(code)?;
2908
2909    // 5. Single commit + wait, then read back logits
2910    writeln!(
2911        code,
2912        "        // 5. Commit all GPU work and wait for completion"
2913    )?;
2914    writeln!(code, "        cmd.commit();")?;
2915    writeln!(code, "        cmd.wait_until_completed();")?;
2916    writeln!(code)?;
2917    writeln!(code, "        // 6. Read back logits from GPU")?;
2918    writeln!(code, "        let logits = unsafe {{")?;
2919    writeln!(
2920        code,
2921        "            let ptr = self.logits_buf.contents() as *const f32;"
2922    )?;
2923    writeln!(
2924        code,
2925        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
2926    )?;
2927    writeln!(code, "        }};")?;
2928    writeln!(code)?;
2929    writeln!(code, "        self.pos += 1;")?;
2930    writeln!(code, "        logits")?;
2931    writeln!(code, "    }}")?;
2932    writeln!(code)?;
2933
2934    // ── forward_profile: instrumented forward with per-operation timing ──
2935    writeln!(
2936        code,
2937        "    /// Profiling forward pass that prints per-stage GPU timing."
2938    )?;
2939    writeln!(code, "    ///")?;
2940    writeln!(
2941        code,
2942        "    /// Each stage is committed and waited on separately so that GPU timestamps"
2943    )?;
2944    writeln!(
2945        code,
2946        "    /// accurately reflect per-operation cost. This is slower than `forward()` due"
2947    )?;
2948    writeln!(
2949        code,
2950        "    /// to the per-stage synchronization, but useful for identifying bottlenecks."
2951    )?;
2952    writeln!(
2953        code,
2954        "    pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
2955    )?;
2956    writeln!(code, "        use std::time::Instant;")?;
2957    writeln!(code)?;
2958    writeln!(
2959        code,
2960        "        // Wait for any pending prefill command buffer"
2961    )?;
2962    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
2963    writeln!(code, "            prev.wait_until_completed();")?;
2964    writeln!(code, "        }}")?;
2965    writeln!(code)?;
2966    writeln!(code, "        let pos = self.pos;")?;
2967    writeln!(code)?;
2968
2969    // Stage: embedding (CPU, no GPU)
2970    writeln!(
2971        code,
2972        "        // ── Stage: Embedding lookup (CPU via unified memory) ──"
2973    )?;
2974    writeln!(code, "        let t_embed = Instant::now();")?;
2975    writeln!(code, "        unsafe {{")?;
2976    writeln!(
2977        code,
2978        "            let embed_ptr = self.embed_buf.contents() as *const f32;"
2979    )?;
2980    writeln!(
2981        code,
2982        "            let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
2983    )?;
2984    writeln!(
2985        code,
2986        "            let residual_ptr = self.residual_buf.contents() as *mut f32;"
2987    )?;
2988    writeln!(code, "            std::ptr::copy_nonoverlapping(")?;
2989    writeln!(
2990        code,
2991        "                embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
2992    )?;
2993    writeln!(code, "                hidden_ptr,")?;
2994    writeln!(code, "                HIDDEN_SIZE,")?;
2995    writeln!(code, "            );")?;
2996    writeln!(
2997        code,
2998        "            std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
2999    )?;
3000    writeln!(code, "        }}")?;
3001    writeln!(code, "        let d_embed = t_embed.elapsed();")?;
3002    writeln!(code)?;
3003
3004    // Stage: Transformer layers (all together on GPU)
3005    writeln!(code, "        // ── Stage: Transformer layers (GPU) ──")?;
3006    writeln!(code, "        let t_layers = Instant::now();")?;
3007    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3008    writeln!(code, "        {{")?;
3009    writeln!(
3010        code,
3011        "            let enc = cmd.new_compute_command_encoder();"
3012    )?;
3013    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3014    writeln!(
3015        code,
3016        "                self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3017    )?;
3018    writeln!(
3019        code,
3020        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3021    )?;
3022    writeln!(
3023        code,
3024        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3025    )?;
3026    writeln!(
3027        code,
3028        "                self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3029    )?;
3030    writeln!(code, "                let kv_offset = pos * {kv_dim};")?;
3031    writeln!(
3032        code,
3033        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3034    )?;
3035    writeln!(
3036        code,
3037        "                self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3038    )?;
3039    writeln!(
3040        code,
3041        "                self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3042    )?;
3043    writeln!(
3044        code,
3045        "                self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3046    )?;
3047    writeln!(
3048        code,
3049        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3050    )?;
3051    writeln!(
3052        code,
3053        "                self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3054    )?;
3055    writeln!(
3056        code,
3057        "                self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3058    )?;
3059    writeln!(
3060        code,
3061        "                self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3062    )?;
3063    writeln!(
3064        code,
3065        "                self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3066    )?;
3067    writeln!(
3068        code,
3069        "                self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3070    )?;
3071    writeln!(code, "            }}")?;
3072    writeln!(code, "            enc.end_encoding();")?;
3073    writeln!(code, "        }}")?;
3074    writeln!(code, "        cmd.commit();")?;
3075    writeln!(code, "        cmd.wait_until_completed();")?;
3076    writeln!(code, "        let d_layers = t_layers.elapsed();")?;
3077    writeln!(code)?;
3078
3079    // Stage: Final norm + logits
3080    writeln!(code, "        // ── Stage: Final norm + logits (GPU) ──")?;
3081    writeln!(code, "        let t_logits = Instant::now();")?;
3082    writeln!(code, "        let cmd2 = self.queue.new_command_buffer();")?;
3083    writeln!(code, "        {{")?;
3084    writeln!(
3085        code,
3086        "            let enc = cmd2.new_compute_command_encoder();"
3087    )?;
3088    writeln!(
3089        code,
3090        "            self.dispatch_rms_norm(&enc, &self.norm_buf);"
3091    )?;
3092    writeln!(
3093        code,
3094        "            self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3095    )?;
3096    writeln!(code, "            enc.end_encoding();")?;
3097    writeln!(code, "        }}")?;
3098    writeln!(code, "        cmd2.commit();")?;
3099    writeln!(code, "        cmd2.wait_until_completed();")?;
3100    writeln!(code, "        let d_logits = t_logits.elapsed();")?;
3101    writeln!(code)?;
3102
3103    // Print profile results
3104    writeln!(
3105        code,
3106        "        eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
3107    )?;
3108    writeln!(code, "            d_embed.as_secs_f64() * 1000.0,")?;
3109    writeln!(code, "            d_layers.as_secs_f64() * 1000.0,")?;
3110    writeln!(code, "            d_logits.as_secs_f64() * 1000.0,")?;
3111    writeln!(
3112        code,
3113        "            (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
3114    )?;
3115    writeln!(code)?;
3116
3117    // Read back logits
3118    writeln!(code, "        let logits = unsafe {{")?;
3119    writeln!(
3120        code,
3121        "            let ptr = self.logits_buf.contents() as *const f32;"
3122    )?;
3123    writeln!(
3124        code,
3125        "            std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3126    )?;
3127    writeln!(code, "        }};")?;
3128    writeln!(code)?;
3129    writeln!(code, "        self.pos += 1;")?;
3130    writeln!(code, "        logits")?;
3131    writeln!(code, "    }}")?;
3132    writeln!(code)?;
3133
3134    // ── forward_prefill: single-token async forward (backward compat) ──
3135    writeln!(
3136        code,
3137        "    /// Asynchronous forward pass for a single prefill token (no logits readback)."
3138    )?;
3139    writeln!(code, "    ///")?;
3140    writeln!(
3141        code,
3142        "    /// Commits the command buffer without waiting, enabling double-buffered"
3143    )?;
3144    writeln!(
3145        code,
3146        "    /// execution: GPU processes token N while CPU encodes token N+1."
3147    )?;
3148    writeln!(
3149        code,
3150        "    pub fn forward_prefill(&mut self, token_id: u32) {{"
3151    )?;
3152    writeln!(code, "        self.forward_prefill_batch(&[token_id]);")?;
3153    writeln!(code, "    }}")?;
3154    writeln!(code)?;
3155
3156    // ── forward_prefill_batch: batched prefill for multiple tokens ──
3157    // Batched matmuls for QKV/O/FFN projections, sequential attention (causal dependency).
3158    let batch_matmul_fn = if is_q8 {
3159        "dispatch_matmul_q8_batch"
3160    } else if is_q4 {
3161        "dispatch_matmul_q4_batch"
3162    } else {
3163        "dispatch_matmul_batch"
3164    };
3165
3166    writeln!(
3167        code,
3168        "    /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
3169    )?;
3170    writeln!(code, "    ///")?;
3171    writeln!(
3172        code,
3173        "    /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
3174    )?;
3175    writeln!(
3176        code,
3177        "    /// of mat-vec), and batched causal attention with a single GPU dispatch."
3178    )?;
3179    writeln!(
3180        code,
3181        "    /// This provides significant speedup during prompt prefill."
3182    )?;
3183    writeln!(
3184        code,
3185        "    pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
3186    )?;
3187    writeln!(code, "        let m = tokens.len().min(MAX_BATCH_SIZE);")?;
3188    writeln!(code, "        if m == 0 {{ return; }}")?;
3189    writeln!(code, "        let start_pos = self.pos;")?;
3190    writeln!(code)?;
3191    writeln!(code, "        // Wait for any pending command buffer")?;
3192    writeln!(code, "        if let Some(prev) = self.prev_cmd.take() {{")?;
3193    writeln!(code, "            prev.wait_until_completed();")?;
3194    writeln!(code, "        }}")?;
3195    writeln!(code)?;
3196
3197    // Upload token IDs and positions to GPU
3198    writeln!(
3199        code,
3200        "        // Upload token IDs and positions to GPU buffers"
3201    )?;
3202    writeln!(code, "        unsafe {{")?;
3203    writeln!(
3204        code,
3205        "            let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
3206    )?;
3207    writeln!(
3208        code,
3209        "            let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
3210    )?;
3211    writeln!(code, "            for i in 0..m {{")?;
3212    writeln!(code, "                *tok_ptr.add(i) = tokens[i];")?;
3213    writeln!(
3214        code,
3215        "                *pos_ptr.add(i) = (start_pos + i) as u32;"
3216    )?;
3217    writeln!(code, "            }}")?;
3218    writeln!(code, "        }}")?;
3219    writeln!(code)?;
3220
3221    writeln!(code, "        let cmd = self.queue.new_command_buffer();")?;
3222    writeln!(code, "        {{")?;
3223    writeln!(
3224        code,
3225        "            let enc = cmd.new_compute_command_encoder();"
3226    )?;
3227    writeln!(code)?;
3228
3229    // 1. Batch embedding lookup
3230    writeln!(
3231        code,
3232        "            // 1. Batch embedding lookup: copy all token embeddings at once"
3233    )?;
3234    writeln!(
3235        code,
3236        "            self.dispatch_copy_embedding_batch(&enc, m);"
3237    )?;
3238    // Copy batch_hidden -> batch_residual
3239    writeln!(
3240        code,
3241        "            self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
3242    )?;
3243    writeln!(code)?;
3244
3245    // 2. Transformer layers
3246    writeln!(code, "            // 2. Transformer layers")?;
3247    writeln!(code, "            for layer in 0..NUM_LAYERS {{")?;
3248    writeln!(code)?;
3249
3250    // Batch RMS norm: residual -> hidden (batched)
3251    writeln!(
3252        code,
3253        "                // Batch RMS norm: batch_residual -> batch_hidden"
3254    )?;
3255    writeln!(
3256        code,
3257        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
3258    )?;
3259
3260    // Batch QKV matmul
3261    writeln!(
3262        code,
3263        "                // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
3264    )?;
3265    writeln!(
3266        code,
3267        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
3268    )?;
3269    writeln!(code)?;
3270
3271    // Fused RoPE on Q+K portions in a single dispatch
3272    let k_float_offset = hidden;
3273    writeln!(
3274        code,
3275        "                // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
3276    )?;
3277    writeln!(
3278        code,
3279        "                self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
3280    )?;
3281    writeln!(code)?;
3282
3283    // Fused KV cache update: copy both K and V in a single dispatch
3284    let v_float_offset = hidden + kv_dim;
3285    writeln!(
3286        code,
3287        "                // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
3288    )?;
3289    writeln!(
3290        code,
3291        "                self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
3292    )?;
3293    writeln!(code)?;
3294
3295    // Batched causal attention: ONE dispatch for all M tokens
3296    writeln!(
3297        code,
3298        "                // Batched causal attention: one dispatch for all M tokens"
3299    )?;
3300    writeln!(
3301        code,
3302        "                self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
3303    )?;
3304    writeln!(code)?;
3305
3306    // Batched O projection: [M, hidden] x [hidden, hidden]^T -> [M, hidden]
3307    writeln!(code, "                // Batched O projection")?;
3308    writeln!(
3309        code,
3310        "                self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
3311    )?;
3312    writeln!(code)?;
3313
3314    // Batch add: residual += attn_proj for all tokens
3315    writeln!(
3316        code,
3317        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
3318    )?;
3319    writeln!(code)?;
3320
3321    // Batch FFN
3322    writeln!(
3323        code,
3324        "                // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
3325    )?;
3326    writeln!(
3327        code,
3328        "                self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
3329    )?;
3330    writeln!(
3331        code,
3332        "                self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
3333    )?;
3334    writeln!(
3335        code,
3336        "                self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
3337    )?;
3338    writeln!(
3339        code,
3340        "                self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
3341    )?;
3342    writeln!(
3343        code,
3344        "                self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
3345    )?;
3346    writeln!(code, "            }}")?;
3347    writeln!(code)?;
3348
3349    // Copy last token's residual to single-token residual_buf for next forward() call
3350    writeln!(
3351        code,
3352        "            // Copy last token's residual to single-token buffer for subsequent forward()"
3353    )?;
3354    writeln!(
3355        code,
3356        "            self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
3357    )?;
3358    writeln!(code)?;
3359    writeln!(code, "            enc.end_encoding();")?;
3360    writeln!(code, "        }}")?;
3361    writeln!(code)?;
3362
3363    writeln!(code, "        cmd.commit();")?;
3364    writeln!(code, "        self.prev_cmd = Some(cmd.to_owned());")?;
3365    writeln!(code, "        self.pos += m;")?;
3366    writeln!(code, "    }}")?;
3367    writeln!(code)?;
3368
3369    // ── reset() — rewind KV cache position for new inference requests ──
3370    writeln!(
3371        code,
3372        "    /// Reset the model state for a new inference request."
3373    )?;
3374    writeln!(code, "    pub fn reset(&mut self) {{")?;
3375    writeln!(code, "        self.pos = 0;")?;
3376    writeln!(code, "        self.prev_cmd = None;")?;
3377    writeln!(code, "    }}")?;
3378    writeln!(code)?;
3379
3380    // ── Private dispatch helpers (all take a shared compute encoder) ──
3381    writeln!(
3382        code,
3383        "    // ── Dispatch helpers (append to a shared compute command encoder) ──"
3384    )?;
3385    writeln!(
3386        code,
3387        "    // These methods set pipeline state + buffers + dispatch on an existing"
3388    )?;
3389    writeln!(
3390        code,
3391        "    // encoder, avoiding per-operation encoder creation overhead."
3392    )?;
3393    writeln!(code)?;
3394
3395    // dispatch_rms_norm
3396    writeln!(
3397        code,
3398        "    /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
3399    )?;
3400    writeln!(
3401        code,
3402        "    fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
3403    )?;
3404    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
3405    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
3406    writeln!(
3407        code,
3408        "        enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
3409    )?;
3410    writeln!(
3411        code,
3412        "        enc.set_buffer(0, Some(&self.residual_buf), 0);"
3413    )?;
3414    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
3415    writeln!(
3416        code,
3417        "        enc.set_buffer(2, Some(&self.hidden_buf), 0);"
3418    )?;
3419    writeln!(
3420        code,
3421        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3422    )?;
3423    writeln!(
3424        code,
3425        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
3426    )?;
3427    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3428    writeln!(
3429        code,
3430        "        let grid_size = MTLSize::new(1, 1, 1);  // single threadgroup"
3431    )?;
3432    writeln!(
3433        code,
3434        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3435    )?;
3436    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3437    writeln!(code, "    }}")?;
3438    writeln!(code)?;
3439
3440    // dispatch_matmul
3441    writeln!(
3442        code,
3443        "    /// Dispatch matrix-vector multiply: weight * input -> output."
3444    )?;
3445    writeln!(
3446        code,
3447        "    fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
3448    )?;
3449    writeln!(code, "        let r: u32 = rows as u32;")?;
3450    writeln!(code, "        let c: u32 = cols as u32;")?;
3451    writeln!(
3452        code,
3453        "        enc.set_compute_pipeline_state(&self.matmul_pipeline);"
3454    )?;
3455    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
3456    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
3457    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
3458    writeln!(
3459        code,
3460        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
3461    )?;
3462    writeln!(
3463        code,
3464        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
3465    )?;
3466    writeln!(
3467        code,
3468        "        // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
3469    )?;
3470    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3471    writeln!(code, "        let num_tg = ((rows + 63) / 64) as u64;")?;
3472    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
3473    writeln!(
3474        code,
3475        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3476    )?;
3477    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3478    writeln!(code, "    }}")?;
3479    writeln!(code)?;
3480
3481    // dispatch_matmul_q8
3482    writeln!(
3483        code,
3484        "    /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
3485    )?;
3486    writeln!(
3487        code,
3488        "    /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
3489    )?;
3490    writeln!(
3491        code,
3492        "    fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
3493    )?;
3494    writeln!(code, "        let r: u32 = rows as u32;")?;
3495    writeln!(code, "        let c: u32 = cols as u32;")?;
3496    writeln!(
3497        code,
3498        "        enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
3499    )?;
3500    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
3501    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
3502    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
3503    writeln!(
3504        code,
3505        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
3506    )?;
3507    writeln!(
3508        code,
3509        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
3510    )?;
3511    writeln!(
3512        code,
3513        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
3514    )?;
3515    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3516    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
3517    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
3518    writeln!(
3519        code,
3520        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3521    )?;
3522    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3523    writeln!(code, "    }}")?;
3524    writeln!(code)?;
3525
3526    // dispatch_matmul_q4
3527    writeln!(
3528        code,
3529        "    /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
3530    )?;
3531    writeln!(
3532        code,
3533        "    /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
3534    )?;
3535    writeln!(
3536        code,
3537        "    fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
3538    )?;
3539    writeln!(code, "        let r: u32 = rows as u32;")?;
3540    writeln!(code, "        let c: u32 = cols as u32;")?;
3541    writeln!(
3542        code,
3543        "        enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
3544    )?;
3545    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
3546    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
3547    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
3548    writeln!(
3549        code,
3550        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
3551    )?;
3552    writeln!(
3553        code,
3554        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
3555    )?;
3556    writeln!(
3557        code,
3558        "        // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
3559    )?;
3560    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3561    writeln!(code, "        let num_tg = ((rows + 31) / 32) as u64;")?;
3562    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
3563    writeln!(
3564        code,
3565        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3566    )?;
3567    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3568    writeln!(code, "    }}")?;
3569    writeln!(code)?;
3570
3571    // dispatch_rope
3572    writeln!(code, "    /// Dispatch RoPE on a buffer in-place.")?;
3573    writeln!(
3574        code,
3575        "    fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
3576    )?;
3577    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
3578    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
3579    writeln!(code, "        let p: u32 = pos as u32;")?;
3580    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
3581    writeln!(
3582        code,
3583        "        let total_pairs = num_heads * (head_dim / 2);"
3584    )?;
3585    writeln!(
3586        code,
3587        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
3588    )?;
3589    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
3590    writeln!(
3591        code,
3592        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
3593    )?;
3594    writeln!(
3595        code,
3596        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
3597    )?;
3598    writeln!(
3599        code,
3600        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
3601    )?;
3602    writeln!(
3603        code,
3604        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
3605    )?;
3606    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3607    writeln!(
3608        code,
3609        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
3610    )?;
3611    writeln!(
3612        code,
3613        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3614    )?;
3615    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3616    writeln!(code, "    }}")?;
3617    writeln!(code)?;
3618
3619    // dispatch_rope_offset
3620    writeln!(
3621        code,
3622        "    /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
3623    )?;
3624    writeln!(
3625        code,
3626        "    fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
3627    )?;
3628    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
3629    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
3630    writeln!(code, "        let p: u32 = pos as u32;")?;
3631    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
3632    writeln!(
3633        code,
3634        "        let total_pairs = num_heads * (head_dim / 2);"
3635    )?;
3636    writeln!(
3637        code,
3638        "        enc.set_compute_pipeline_state(&self.rope_pipeline);"
3639    )?;
3640    writeln!(
3641        code,
3642        "        enc.set_buffer(0, Some(buf), byte_offset as u64);"
3643    )?;
3644    writeln!(
3645        code,
3646        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
3647    )?;
3648    writeln!(
3649        code,
3650        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
3651    )?;
3652    writeln!(
3653        code,
3654        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
3655    )?;
3656    writeln!(
3657        code,
3658        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
3659    )?;
3660    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3661    writeln!(
3662        code,
3663        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
3664    )?;
3665    writeln!(
3666        code,
3667        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3668    )?;
3669    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3670    writeln!(code, "    }}")?;
3671    writeln!(code)?;
3672
3673    // dispatch_attention
3674    writeln!(
3675        code,
3676        "    /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
3677    )?;
3678    writeln!(
3679        code,
3680        "    fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
3681    )?;
3682    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
3683    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
3684    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
3685    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
3686    writeln!(
3687        code,
3688        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
3689    )?;
3690    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
3691    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
3692    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
3693    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
3694    writeln!(
3695        code,
3696        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
3697    )?;
3698    writeln!(
3699        code,
3700        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
3701    )?;
3702    writeln!(
3703        code,
3704        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
3705    )?;
3706    writeln!(
3707        code,
3708        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
3709    )?;
3710    writeln!(
3711        code,
3712        "        // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
3713    )?;
3714    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3715    writeln!(
3716        code,
3717        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
3718    )?;
3719    writeln!(
3720        code,
3721        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3722    )?;
3723    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3724    writeln!(code, "    }}")?;
3725    writeln!(code)?;
3726
3727    // dispatch_attention_offset
3728    writeln!(
3729        code,
3730        "    /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
3731    )?;
3732    writeln!(
3733        code,
3734        "    fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
3735    )?;
3736    writeln!(code, "        let sl: u32 = seq_len as u32;")?;
3737    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
3738    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
3739    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
3740    writeln!(
3741        code,
3742        "        enc.set_compute_pipeline_state(&self.attention_pipeline);"
3743    )?;
3744    writeln!(
3745        code,
3746        "        enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
3747    )?;
3748    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
3749    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
3750    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
3751    writeln!(
3752        code,
3753        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
3754    )?;
3755    writeln!(
3756        code,
3757        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
3758    )?;
3759    writeln!(
3760        code,
3761        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
3762    )?;
3763    writeln!(
3764        code,
3765        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
3766    )?;
3767    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3768    writeln!(
3769        code,
3770        "        let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
3771    )?;
3772    writeln!(
3773        code,
3774        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3775    )?;
3776    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3777    writeln!(code, "    }}")?;
3778    writeln!(code)?;
3779
3780    // dispatch_silu_mul
3781    writeln!(code, "    /// Dispatch fused SiLU-multiply kernel.")?;
3782    writeln!(
3783        code,
3784        "    fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
3785    )?;
3786    writeln!(code, "        let count: u32 = n as u32;")?;
3787    writeln!(
3788        code,
3789        "        enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
3790    )?;
3791    writeln!(code, "        enc.set_buffer(0, Some(gate), 0);")?;
3792    writeln!(code, "        enc.set_buffer(1, Some(up), 0);")?;
3793    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
3794    writeln!(
3795        code,
3796        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
3797    )?;
3798    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3799    writeln!(
3800        code,
3801        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
3802    )?;
3803    writeln!(
3804        code,
3805        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3806    )?;
3807    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3808    writeln!(code, "    }}")?;
3809    writeln!(code)?;
3810
3811    // dispatch_silu_mul_fused
3812    writeln!(
3813        code,
3814        "    /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
3815    )?;
3816    writeln!(
3817        code,
3818        "    /// gate_up_buf contains [gate(n), up(n)] contiguously."
3819    )?;
3820    writeln!(
3821        code,
3822        "    fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
3823    )?;
3824    writeln!(code, "        let count: u32 = n as u32;")?;
3825    writeln!(
3826        code,
3827        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
3828    )?;
3829    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
3830    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
3831    writeln!(
3832        code,
3833        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
3834    )?;
3835    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3836    writeln!(
3837        code,
3838        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
3839    )?;
3840    writeln!(
3841        code,
3842        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3843    )?;
3844    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3845    writeln!(code, "    }}")?;
3846    writeln!(code)?;
3847
3848    // dispatch_copy (simple src -> dst copy via compute kernel)
3849    writeln!(
3850        code,
3851        "    /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
3852    )?;
3853    writeln!(
3854        code,
3855        "    fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
3856    )?;
3857    writeln!(code, "        let n: u32 = count as u32;")?;
3858    writeln!(
3859        code,
3860        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
3861    )?;
3862    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
3863    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
3864    writeln!(
3865        code,
3866        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3867    )?;
3868    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3869    writeln!(
3870        code,
3871        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
3872    )?;
3873    writeln!(
3874        code,
3875        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3876    )?;
3877    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3878    writeln!(code, "    }}")?;
3879    writeln!(code)?;
3880
3881    // dispatch_copy_offset (copy from src[src_offset..] -> dst)
3882    writeln!(
3883        code,
3884        "    /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
3885    )?;
3886    writeln!(
3887        code,
3888        "    /// Used for embedding table lookup (copy a specific row)."
3889    )?;
3890    writeln!(
3891        code,
3892        "    fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
3893    )?;
3894    writeln!(code, "        let off: u32 = src_offset as u32;")?;
3895    writeln!(code, "        let n: u32 = count as u32;")?;
3896    writeln!(
3897        code,
3898        "        enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
3899    )?;
3900    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
3901    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
3902    writeln!(
3903        code,
3904        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
3905    )?;
3906    writeln!(
3907        code,
3908        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3909    )?;
3910    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3911    writeln!(
3912        code,
3913        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
3914    )?;
3915    writeln!(
3916        code,
3917        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3918    )?;
3919    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3920    writeln!(code, "    }}")?;
3921    writeln!(code)?;
3922
3923    // dispatch_copy_from_offset (copy from src at byte offset to dst at float offset)
3924    writeln!(
3925        code,
3926        "    /// Dispatch copy from source at byte offset to destination at float offset."
3927    )?;
3928    writeln!(
3929        code,
3930        "    /// Used for KV cache updates from fused QKV buffer."
3931    )?;
3932    writeln!(
3933        code,
3934        "    fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
3935    )?;
3936    writeln!(code, "        let n: u32 = count as u32;")?;
3937    writeln!(
3938        code,
3939        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
3940    )?;
3941    writeln!(
3942        code,
3943        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
3944    )?;
3945    writeln!(
3946        code,
3947        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
3948    )?;
3949    writeln!(
3950        code,
3951        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3952    )?;
3953    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3954    writeln!(
3955        code,
3956        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
3957    )?;
3958    writeln!(
3959        code,
3960        "        enc.dispatch_thread_groups(grid_size, tg_size);"
3961    )?;
3962    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3963    writeln!(code, "    }}")?;
3964    writeln!(code)?;
3965
3966    // dispatch_copy_to_offset (copy src -> dst[dst_offset..])
3967    writeln!(
3968        code,
3969        "    /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
3970    )?;
3971    writeln!(
3972        code,
3973        "    /// Used for KV cache updates (write to a specific position in the cache)."
3974    )?;
3975    writeln!(
3976        code,
3977        "    fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
3978    )?;
3979    writeln!(code, "        let n: u32 = count as u32;")?;
3980    writeln!(
3981        code,
3982        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
3983    )?;
3984    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
3985    writeln!(
3986        code,
3987        "        enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
3988    )?;
3989    writeln!(
3990        code,
3991        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3992    )?;
3993    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
3994    writeln!(
3995        code,
3996        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
3997    )?;
3998    writeln!(
3999        code,
4000        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4001    )?;
4002    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4003    writeln!(code, "    }}")?;
4004    writeln!(code)?;
4005
4006    // dispatch_add_inplace (residual connection, no blit needed)
4007    writeln!(
4008        code,
4009        "    /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
4010    )?;
4011    writeln!(
4012        code,
4013        "    fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
4014    )?;
4015    writeln!(code, "        let count: u32 = n as u32;")?;
4016    writeln!(
4017        code,
4018        "        enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
4019    )?;
4020    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
4021    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
4022    writeln!(
4023        code,
4024        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4025    )?;
4026    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4027    writeln!(
4028        code,
4029        "        let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4030    )?;
4031    writeln!(
4032        code,
4033        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4034    )?;
4035    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4036    writeln!(code, "    }}")?;
4037    writeln!(code)?;
4038
4039    // ── Batched prefill dispatch helpers ──
4040    writeln!(code, "    // ── Batched prefill dispatch helpers ──")?;
4041    writeln!(code)?;
4042
4043    // dispatch_copy_embedding_batch
4044    writeln!(
4045        code,
4046        "    /// Dispatch batched embedding lookup: copy M token embeddings at once."
4047    )?;
4048    writeln!(
4049        code,
4050        "    fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
4051    )?;
4052    writeln!(code, "        let dim: u32 = HIDDEN_SIZE as u32;")?;
4053    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4054    writeln!(
4055        code,
4056        "        enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
4057    )?;
4058    writeln!(code, "        enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
4059    writeln!(
4060        code,
4061        "        enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
4062    )?;
4063    writeln!(
4064        code,
4065        "        enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
4066    )?;
4067    writeln!(
4068        code,
4069        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
4070    )?;
4071    writeln!(
4072        code,
4073        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4074    )?;
4075    writeln!(code, "        let total = num_tokens * HIDDEN_SIZE;")?;
4076    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4077    writeln!(
4078        code,
4079        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4080    )?;
4081    writeln!(
4082        code,
4083        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4084    )?;
4085    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4086    writeln!(code, "    }}")?;
4087    writeln!(code)?;
4088
4089    // dispatch_rms_norm_batch
4090    writeln!(
4091        code,
4092        "    /// Dispatch batched RMS norm: normalizes M vectors at once."
4093    )?;
4094    writeln!(
4095        code,
4096        "    fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
4097    )?;
4098    writeln!(code, "        let n: u32 = HIDDEN_SIZE as u32;")?;
4099    writeln!(code, "        let eps: f32 = RMS_NORM_EPS;")?;
4100    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4101    writeln!(
4102        code,
4103        "        enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
4104    )?;
4105    writeln!(code, "        enc.set_buffer(0, Some(input), 0);")?;
4106    writeln!(code, "        enc.set_buffer(1, Some(weight), 0);")?;
4107    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4108    writeln!(
4109        code,
4110        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4111    )?;
4112    writeln!(
4113        code,
4114        "        enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4115    )?;
4116    writeln!(
4117        code,
4118        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4119    )?;
4120    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4121    writeln!(
4122        code,
4123        "        let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
4124    )?;
4125    writeln!(
4126        code,
4127        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4128    )?;
4129    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4130    writeln!(code, "    }}")?;
4131    writeln!(code)?;
4132
4133    // dispatch_matmul_batch (f32)
4134    writeln!(
4135        code,
4136        "    /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4137    )?;
4138    writeln!(
4139        code,
4140        "    fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4141    )?;
4142    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4143    writeln!(code, "        let r: u32 = rows as u32;")?;
4144    writeln!(code, "        let c: u32 = cols as u32;")?;
4145    writeln!(
4146        code,
4147        "        enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
4148    )?;
4149    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4150    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4151    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4152    writeln!(
4153        code,
4154        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4155    )?;
4156    writeln!(
4157        code,
4158        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4159    )?;
4160    writeln!(
4161        code,
4162        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4163    )?;
4164    writeln!(
4165        code,
4166        "        let row_tgs = (rows + 63) / 64;  // 64 rows per threadgroup for f32"
4167    )?;
4168    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
4169    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4170    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4171    writeln!(
4172        code,
4173        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4174    )?;
4175    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4176    writeln!(code, "    }}")?;
4177    writeln!(code)?;
4178
4179    // dispatch_matmul_q8_batch
4180    writeln!(
4181        code,
4182        "    /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4183    )?;
4184    writeln!(code, "    ///")?;
4185    writeln!(
4186        code,
4187        "    /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
4188    )?;
4189    writeln!(
4190        code,
4191        "    /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
4192    )?;
4193    writeln!(
4194        code,
4195        "    fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4196    )?;
4197    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4198    writeln!(code, "        let r: u32 = rows as u32;")?;
4199    writeln!(code, "        let c: u32 = cols as u32;")?;
4200    writeln!(
4201        code,
4202        "        // Tile size must match TOKENS_PER_TG_Q8 in shaders."
4203    )?;
4204    writeln!(code, "        const TOKENS_PER_TG_Q8: usize = 4;")?;
4205    writeln!(code, "        if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
4206    writeln!(
4207        code,
4208        "            enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
4209    )?;
4210    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
4211    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
4212    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
4213    writeln!(
4214        code,
4215        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4216    )?;
4217    writeln!(
4218        code,
4219        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4220    )?;
4221    writeln!(
4222        code,
4223        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4224    )?;
4225    writeln!(
4226        code,
4227        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
4228    )?;
4229    writeln!(
4230        code,
4231        "            let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
4232    )?;
4233    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
4234    writeln!(
4235        code,
4236        "            let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
4237    )?;
4238    writeln!(
4239        code,
4240        "            enc.dispatch_thread_groups(grid_size, tg_size);"
4241    )?;
4242    writeln!(code, "        }} else {{")?;
4243    writeln!(
4244        code,
4245        "            enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
4246    )?;
4247    writeln!(code, "            enc.set_buffer(0, Some(weight), 0);")?;
4248    writeln!(code, "            enc.set_buffer(1, Some(input), 0);")?;
4249    writeln!(code, "            enc.set_buffer(2, Some(output), 0);")?;
4250    writeln!(
4251        code,
4252        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4253    )?;
4254    writeln!(
4255        code,
4256        "            enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4257    )?;
4258    writeln!(
4259        code,
4260        "            enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4261    )?;
4262    writeln!(
4263        code,
4264        "            let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q8"
4265    )?;
4266    writeln!(
4267        code,
4268        "            let num_tg = (row_tgs * num_tokens) as u64;"
4269    )?;
4270    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
4271    writeln!(
4272        code,
4273        "            let grid_size = MTLSize::new(num_tg, 1, 1);"
4274    )?;
4275    writeln!(
4276        code,
4277        "            enc.dispatch_thread_groups(grid_size, tg_size);"
4278    )?;
4279    writeln!(code, "        }}")?;
4280    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4281    writeln!(code, "    }}")?;
4282    writeln!(code)?;
4283
4284    // dispatch_matmul_q4_batch
4285    writeln!(
4286        code,
4287        "    /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4288    )?;
4289    writeln!(
4290        code,
4291        "    fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4292    )?;
4293    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4294    writeln!(code, "        let r: u32 = rows as u32;")?;
4295    writeln!(code, "        let c: u32 = cols as u32;")?;
4296    writeln!(
4297        code,
4298        "        enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
4299    )?;
4300    writeln!(code, "        enc.set_buffer(0, Some(weight), 0);")?;
4301    writeln!(code, "        enc.set_buffer(1, Some(input), 0);")?;
4302    writeln!(code, "        enc.set_buffer(2, Some(output), 0);")?;
4303    writeln!(
4304        code,
4305        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4306    )?;
4307    writeln!(
4308        code,
4309        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4310    )?;
4311    writeln!(
4312        code,
4313        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4314    )?;
4315    writeln!(
4316        code,
4317        "        let row_tgs = (rows + 31) / 32;  // 32 rows per threadgroup for Q4"
4318    )?;
4319    writeln!(code, "        let num_tg = (row_tgs * num_tokens) as u64;")?;
4320    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4321    writeln!(code, "        let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4322    writeln!(
4323        code,
4324        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4325    )?;
4326    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4327    writeln!(code, "    }}")?;
4328    writeln!(code)?;
4329
4330    // dispatch_rope_batch
4331    writeln!(
4332        code,
4333        "    /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
4334    )?;
4335    writeln!(
4336        code,
4337        "    /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
4338    )?;
4339    writeln!(
4340        code,
4341        "    /// `row_stride` is the number of floats per token row in the batch buffer."
4342    )?;
4343    writeln!(
4344        code,
4345        "    fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
4346    )?;
4347    writeln!(code, "        let nh: u32 = num_heads as u32;")?;
4348    writeln!(code, "        let hd: u32 = head_dim as u32;")?;
4349    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4350    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4351    writeln!(
4352        code,
4353        "        let pairs_per_token = num_heads * (head_dim / 2);"
4354    )?;
4355    writeln!(
4356        code,
4357        "        let total_pairs = num_tokens * pairs_per_token;"
4358    )?;
4359    // The rope_batch kernel expects contiguous [M, num_heads * head_dim] data.
4360    // Since our batch_qkv_buf is [M, qkv_rows] and Q/K are at offsets within each row,
4361    // we need to pass the buffer at the right byte offset for each token's data.
4362    // Actually, the rope_batch kernel accesses data[token * (num_heads * head_dim) + ...],
4363    // but our layout is data[token * row_stride + data_float_offset + ...].
4364    // We need the kernel to know the row_stride. Let me adjust the kernel approach:
4365    // Since Q and K are contiguous within each token's qkv_rows, and the batch buffer
4366    // is [M, qkv_rows], we can pass the buffer at offset (data_float_offset * 4) and
4367    // use a stride parameter. But the rope_batch kernel as written expects [M, num_heads*head_dim].
4368    //
4369    // Simplest approach: use the single-token rope kernel for each token in a loop.
4370    // This is still efficient because we're dispatching all within the same command encoder.
4371    writeln!(
4372        code,
4373        "        // Apply RoPE to each token individually (different positions, non-contiguous layout)"
4374    )?;
4375    writeln!(code, "        for t in 0..num_tokens {{")?;
4376    writeln!(
4377        code,
4378        "            let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
4379    )?;
4380    writeln!(
4381        code,
4382        "            let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
4383    )?;
4384    writeln!(
4385        code,
4386        "            enc.set_compute_pipeline_state(&self.rope_pipeline);"
4387    )?;
4388    writeln!(
4389        code,
4390        "            enc.set_buffer(0, Some(buf), byte_offset as u64);"
4391    )?;
4392    writeln!(
4393        code,
4394        "            enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4395    )?;
4396    writeln!(
4397        code,
4398        "            enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4399    )?;
4400    writeln!(
4401        code,
4402        "            enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4403    )?;
4404    writeln!(
4405        code,
4406        "            enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4407    )?;
4408    writeln!(code, "            let tg_size = MTLSize::new(256, 1, 1);")?;
4409    writeln!(
4410        code,
4411        "            let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
4412    )?;
4413    writeln!(
4414        code,
4415        "            enc.dispatch_thread_groups(grid_size, tg_size);"
4416    )?;
4417    writeln!(code, "            unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4418    writeln!(code, "        }}")?;
4419    writeln!(code, "    }}")?;
4420    writeln!(code)?;
4421
4422    // dispatch_silu_mul_fused_batch
4423    writeln!(
4424        code,
4425        "    /// Dispatch batched fused SiLU-multiply for M tokens."
4426    )?;
4427    writeln!(
4428        code,
4429        "    fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
4430    )?;
4431    writeln!(code, "        let count: u32 = n as u32;")?;
4432    writeln!(code, "        let nt: u32 = num_tokens as u32;")?;
4433    writeln!(
4434        code,
4435        "        enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
4436    )?;
4437    writeln!(code, "        enc.set_buffer(0, Some(gate_up), 0);")?;
4438    writeln!(code, "        enc.set_buffer(1, Some(output), 0);")?;
4439    writeln!(
4440        code,
4441        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4442    )?;
4443    writeln!(
4444        code,
4445        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4446    )?;
4447    writeln!(code, "        let total = num_tokens * n;")?;
4448    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4449    writeln!(
4450        code,
4451        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4452    )?;
4453    writeln!(
4454        code,
4455        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4456    )?;
4457    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4458    writeln!(code, "    }}")?;
4459    writeln!(code)?;
4460
4461    // dispatch_add_inplace_batch_n (add n elements in-place)
4462    writeln!(
4463        code,
4464        "    /// Dispatch in-place add for total_n elements: a[i] += b[i]."
4465    )?;
4466    writeln!(
4467        code,
4468        "    fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
4469    )?;
4470    writeln!(code, "        let count: u32 = total_n as u32;")?;
4471    writeln!(
4472        code,
4473        "        enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
4474    )?;
4475    writeln!(code, "        enc.set_buffer(0, Some(a), 0);")?;
4476    writeln!(code, "        enc.set_buffer(1, Some(b), 0);")?;
4477    writeln!(
4478        code,
4479        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4480    )?;
4481    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4482    writeln!(
4483        code,
4484        "        let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
4485    )?;
4486    writeln!(
4487        code,
4488        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4489    )?;
4490    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4491    writeln!(code, "    }}")?;
4492    writeln!(code)?;
4493
4494    // dispatch_add_inplace_batch_copy (copy src to dst using copy_buffer kernel)
4495    writeln!(
4496        code,
4497        "    /// Copy src to dst using compute copy kernel (for batch residual init)."
4498    )?;
4499    writeln!(
4500        code,
4501        "    fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4502    )?;
4503    writeln!(code, "        let n: u32 = count as u32;")?;
4504    writeln!(
4505        code,
4506        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4507    )?;
4508    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4509    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4510    writeln!(
4511        code,
4512        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4513    )?;
4514    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4515    writeln!(
4516        code,
4517        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4518    )?;
4519    writeln!(
4520        code,
4521        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4522    )?;
4523    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4524    writeln!(code, "    }}")?;
4525    writeln!(code)?;
4526
4527    // dispatch_copy_to_offset_bytes (copy src to dst at float offset)
4528    writeln!(
4529        code,
4530        "    /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
4531    )?;
4532    writeln!(
4533        code,
4534        "    fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4535    )?;
4536    writeln!(code, "        let n: u32 = count as u32;")?;
4537    writeln!(
4538        code,
4539        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4540    )?;
4541    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4542    writeln!(
4543        code,
4544        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4545    )?;
4546    writeln!(
4547        code,
4548        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4549    )?;
4550    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4551    writeln!(
4552        code,
4553        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4554    )?;
4555    writeln!(
4556        code,
4557        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4558    )?;
4559    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4560    writeln!(code, "    }}")?;
4561    writeln!(code)?;
4562
4563    // dispatch_copy_from_offset_bytes (copy from src at byte offset to dst at float offset)
4564    writeln!(
4565        code,
4566        "    /// Copy from src at byte offset to dst at float offset."
4567    )?;
4568    writeln!(
4569        code,
4570        "    fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4571    )?;
4572    writeln!(code, "        let n: u32 = count as u32;")?;
4573    writeln!(
4574        code,
4575        "        enc.set_compute_pipeline_state(&self.copy_pipeline);"
4576    )?;
4577    writeln!(
4578        code,
4579        "        enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4580    )?;
4581    writeln!(
4582        code,
4583        "        enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4584    )?;
4585    writeln!(
4586        code,
4587        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4588    )?;
4589    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4590    writeln!(
4591        code,
4592        "        let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4593    )?;
4594    writeln!(
4595        code,
4596        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4597    )?;
4598    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4599    writeln!(code, "    }}")?;
4600    writeln!(code)?;
4601
4602    // dispatch_copy_kv_batch
4603    writeln!(
4604        code,
4605        "    /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
4606    )?;
4607    writeln!(
4608        code,
4609        "    fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
4610    )?;
4611    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
4612    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
4613    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
4614    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
4615    writeln!(code, "        let so: u32 = src_offset as u32;")?;
4616    writeln!(
4617        code,
4618        "        enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
4619    )?;
4620    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4621    writeln!(code, "        enc.set_buffer(1, Some(dst), 0);")?;
4622    writeln!(
4623        code,
4624        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
4625    )?;
4626    writeln!(
4627        code,
4628        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
4629    )?;
4630    writeln!(
4631        code,
4632        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
4633    )?;
4634    writeln!(
4635        code,
4636        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
4637    )?;
4638    writeln!(
4639        code,
4640        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
4641    )?;
4642    writeln!(code, "        let total = num_tokens * kv_dim;")?;
4643    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4644    writeln!(
4645        code,
4646        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4647    )?;
4648    writeln!(
4649        code,
4650        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4651    )?;
4652    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4653    writeln!(code, "    }}")?;
4654    writeln!(code)?;
4655
4656    // dispatch_attention_batch
4657    writeln!(
4658        code,
4659        "    /// Dispatch batched causal attention: one dispatch for all M tokens."
4660    )?;
4661    writeln!(
4662        code,
4663        "    fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
4664    )?;
4665    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
4666    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
4667    writeln!(code, "        let nh: u32 = NUM_HEADS as u32;")?;
4668    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4669    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4670    writeln!(code, "        let qs: u32 = q_stride as u32;")?;
4671    writeln!(
4672        code,
4673        "        enc.set_compute_pipeline_state(&self.attention_batch_pipeline);"
4674    )?;
4675    writeln!(code, "        enc.set_buffer(0, Some(q_buf), 0);")?;
4676    writeln!(code, "        enc.set_buffer(1, Some(k_cache), 0);")?;
4677    writeln!(code, "        enc.set_buffer(2, Some(v_cache), 0);")?;
4678    writeln!(code, "        enc.set_buffer(3, Some(output), 0);")?;
4679    writeln!(
4680        code,
4681        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
4682    )?;
4683    writeln!(
4684        code,
4685        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
4686    )?;
4687    writeln!(
4688        code,
4689        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4690    )?;
4691    writeln!(
4692        code,
4693        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4694    )?;
4695    writeln!(
4696        code,
4697        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4698    )?;
4699    writeln!(
4700        code,
4701        "        enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
4702    )?;
4703    writeln!(
4704        code,
4705        "        // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
4706    )?;
4707    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4708    writeln!(
4709        code,
4710        "        let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
4711    )?;
4712    writeln!(
4713        code,
4714        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4715    )?;
4716    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4717    writeln!(code, "    }}")?;
4718    writeln!(code)?;
4719
4720    // dispatch_rope_qk_batch — fused Q+K RoPE in a single dispatch
4721    writeln!(
4722        code,
4723        "    /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
4724    )?;
4725    writeln!(
4726        code,
4727        "    /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
4728    )?;
4729    writeln!(
4730        code,
4731        "    fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
4732    )?;
4733    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
4734    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
4735    writeln!(code, "        let nq: u32 = NUM_HEADS as u32;")?;
4736    writeln!(code, "        let nkv: u32 = NUM_KV_HEADS as u32;")?;
4737    writeln!(code, "        let hd: u32 = HEAD_DIM as u32;")?;
4738    writeln!(code, "        let qs: u32 = qkv_stride as u32;")?;
4739    writeln!(code, "        let theta: f32 = ROPE_THETA;")?;
4740    writeln!(
4741        code,
4742        "        enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
4743    )?;
4744    writeln!(code, "        enc.set_buffer(0, Some(buf), 0);")?;
4745    writeln!(
4746        code,
4747        "        enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
4748    )?;
4749    writeln!(
4750        code,
4751        "        enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
4752    )?;
4753    writeln!(
4754        code,
4755        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
4756    )?;
4757    writeln!(
4758        code,
4759        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4760    )?;
4761    writeln!(
4762        code,
4763        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4764    )?;
4765    writeln!(
4766        code,
4767        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
4768    )?;
4769    writeln!(
4770        code,
4771        "        enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4772    )?;
4773    writeln!(
4774        code,
4775        "        let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
4776    )?;
4777    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4778    writeln!(
4779        code,
4780        "        let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4781    )?;
4782    writeln!(
4783        code,
4784        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4785    )?;
4786    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4787    writeln!(code, "    }}")?;
4788    writeln!(code)?;
4789
4790    // dispatch_copy_kv_both_batch — fused K+V cache copy in a single dispatch
4791    writeln!(
4792        code,
4793        "    /// Dispatch fused K+V cache copy in one kernel launch."
4794    )?;
4795    writeln!(
4796        code,
4797        "    /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
4798    )?;
4799    writeln!(
4800        code,
4801        "    fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
4802    )?;
4803    writeln!(code, "        let m_val: u32 = num_tokens as u32;")?;
4804    writeln!(code, "        let kv: u32 = kv_dim as u32;")?;
4805    writeln!(code, "        let bp: u32 = base_pos as u32;")?;
4806    writeln!(code, "        let ss: u32 = src_stride as u32;")?;
4807    writeln!(code, "        let ko: u32 = k_offset as u32;")?;
4808    writeln!(code, "        let vo: u32 = v_offset as u32;")?;
4809    writeln!(
4810        code,
4811        "        enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
4812    )?;
4813    writeln!(code, "        enc.set_buffer(0, Some(src), 0);")?;
4814    writeln!(code, "        enc.set_buffer(1, Some(k_dst), 0);")?;
4815    writeln!(code, "        enc.set_buffer(2, Some(v_dst), 0);")?;
4816    writeln!(
4817        code,
4818        "        enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
4819    )?;
4820    writeln!(
4821        code,
4822        "        enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
4823    )?;
4824    writeln!(
4825        code,
4826        "        enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
4827    )?;
4828    writeln!(
4829        code,
4830        "        enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
4831    )?;
4832    writeln!(
4833        code,
4834        "        enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
4835    )?;
4836    writeln!(
4837        code,
4838        "        enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
4839    )?;
4840    writeln!(
4841        code,
4842        "        let total = num_tokens * kv_dim * 2;  // K + V"
4843    )?;
4844    writeln!(code, "        let tg_size = MTLSize::new(256, 1, 1);")?;
4845    writeln!(
4846        code,
4847        "        let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4848    )?;
4849    writeln!(
4850        code,
4851        "        enc.dispatch_thread_groups(grid_size, tg_size);"
4852    )?;
4853    writeln!(code, "        unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4854    writeln!(code, "    }}")?;
4855
4856    writeln!(code, "}}")?;
4857    writeln!(code)?;
4858
4859    Ok(())
4860}
4861
4862fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
4863    writeln!(
4864        code,
4865        "// ── Helper functions ──────────────────────────────────"
4866    )?;
4867    writeln!(code)?;
4868    writeln!(
4869        code,
4870        "/// Create a compute pipeline from a named function in the Metal library."
4871    )?;
4872    writeln!(
4873        code,
4874        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
4875    )?;
4876    writeln!(
4877        code,
4878        "    let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
4879    )?;
4880    writeln!(
4881        code,
4882        "    device.new_compute_pipeline_state_with_function(&func)"
4883    )?;
4884    writeln!(
4885        code,
4886        "        .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
4887    )?;
4888    writeln!(code, "}}")?;
4889    writeln!(code)?;
4890
4891    Ok(())
4892}
4893
4894// ---------------------------------------------------------------------------
4895// main.rs generation
4896// ---------------------------------------------------------------------------
4897
4898fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
4899    let _sanitized = sanitize_name(model_name);
4900    let _vocab = config.vocab_size;
4901
4902    let mut code = String::with_capacity(16 * 1024);
4903    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
4904    writeln!(
4905        code,
4906        "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
4907    )?;
4908    writeln!(code)?;
4909    writeln!(code, "mod model;")?;
4910    writeln!(code)?;
4911    writeln!(code, "use std::io::Write;")?;
4912    writeln!(code, "use std::time::Instant;")?;
4913    writeln!(code, "use serde::Deserialize;")?;
4914    writeln!(code)?;
4915
4916    // -- main function --
4917    writeln!(code, "fn main() {{")?;
4918    writeln!(
4919        code,
4920        "    let args: Vec<String> = std::env::args().collect();"
4921    )?;
4922    writeln!(code)?;
4923    writeln!(
4924        code,
4925        "    // Detect --serve mode (only requires weights + tokenizer)"
4926    )?;
4927    writeln!(
4928        code,
4929        "    let serve_mode = args.iter().any(|a| a == \"--serve\");"
4930    )?;
4931    writeln!(code)?;
4932    writeln!(code, "    if !serve_mode && args.len() < 4 {{")?;
4933    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
4934    writeln!(code, "        eprintln!(\"       {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
4935    writeln!(code, "        std::process::exit(1);")?;
4936    writeln!(code, "    }}")?;
4937    writeln!(code)?;
4938    writeln!(code, "    if serve_mode && args.len() < 3 {{")?;
4939    writeln!(code, "        eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
4940    writeln!(code, "        std::process::exit(1);")?;
4941    writeln!(code, "    }}")?;
4942    writeln!(code)?;
4943    writeln!(code, "    let weights_path = &args[1];")?;
4944    writeln!(code, "    let tokenizer_path = &args[2];")?;
4945    writeln!(code)?;
4946    writeln!(code, "    // Parse optional flags")?;
4947    writeln!(code, "    let mut max_tokens: usize = 128;")?;
4948    writeln!(code, "    let mut port: u16 = 8080;")?;
4949    writeln!(
4950        code,
4951        "    let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
4952    )?;
4953    writeln!(
4954        code,
4955        "    let profile = args.iter().any(|a| a == \"--profile\");"
4956    )?;
4957    writeln!(code, "    let mut i = 3;")?;
4958    writeln!(code, "    while i < args.len() {{")?;
4959    writeln!(
4960        code,
4961        "        if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
4962    )?;
4963    writeln!(
4964        code,
4965        "            max_tokens = args[i + 1].parse().unwrap_or(128);"
4966    )?;
4967    writeln!(code, "            i += 2;")?;
4968    writeln!(
4969        code,
4970        "        }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
4971    )?;
4972    writeln!(
4973        code,
4974        "            port = args[i + 1].parse().unwrap_or(8080);"
4975    )?;
4976    writeln!(code, "            i += 2;")?;
4977    writeln!(code, "        }} else if args[i] == \"--serve\" {{")?;
4978    writeln!(code, "            i += 1;")?;
4979    writeln!(code, "        }} else if args[i] == \"--profile\" {{")?;
4980    writeln!(code, "            i += 1;")?;
4981    writeln!(code, "        }} else {{")?;
4982    writeln!(code, "            i += 1;")?;
4983    writeln!(code, "        }}")?;
4984    writeln!(code, "    }}")?;
4985    writeln!(code)?;
4986
4987    // -- load model (shared by both modes) --
4988    writeln!(
4989        code,
4990        "    // Memory-map weights for zero-copy loading on Apple Silicon"
4991    )?;
4992    writeln!(
4993        code,
4994        "    let weights_file = std::fs::File::open(weights_path)"
4995    )?;
4996    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
4997    writeln!(
4998        code,
4999        "    let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
5000    )?;
5001    writeln!(code)?;
5002    writeln!(code, "    // Load tokenizer")?;
5003    writeln!(
5004        code,
5005        "    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
5006    )?;
5007    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
5008    writeln!(code)?;
5009    writeln!(code, "    // Create Metal model")?;
5010    writeln!(code, "    eprintln!(\"Loading model onto Metal GPU...\");")?;
5011    writeln!(
5012        code,
5013        "    let mut model = model::MetalModel::new(&weights_mmap);"
5014    )?;
5015    writeln!(code)?;
5016
5017    // -- branch: serve vs CLI --
5018    writeln!(code, "    if serve_mode {{")?;
5019    writeln!(code, "        serve(model, tokenizer, port);")?;
5020    writeln!(code, "    }} else {{")?;
5021    writeln!(code, "        let prompt = &args[3];")?;
5022    writeln!(
5023        code,
5024        "        cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
5025    )?;
5026    writeln!(code, "    }}")?;
5027    writeln!(code, "}}")?;
5028    writeln!(code)?;
5029
5030    // -- cli_mode function --
5031    writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
5032    writeln!(code, "    // Tokenize prompt")?;
5033    writeln!(code, "    let encoding = tokenizer.encode(prompt, true)")?;
5034    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
5035    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
5036    writeln!(code)?;
5037    writeln!(
5038        code,
5039        "    // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
5040    )?;
5041    writeln!(
5042        code,
5043        "    // Uses double-buffered batch dispatch for GPU-efficient matmul."
5044    )?;
5045    writeln!(
5046        code,
5047        "    // The last token uses synchronous forward() to get logits."
5048    )?;
5049    writeln!(code, "    let prompt_len = prompt_tokens.len();")?;
5050    writeln!(code, "    let prefill_start = Instant::now();")?;
5051    writeln!(code, "    let logits = if prompt_len > 1 {{")?;
5052    writeln!(
5053        code,
5054        "        model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
5055    )?;
5056    writeln!(code, "        model.forward(prompt_tokens[prompt_len - 1])")?;
5057    writeln!(code, "    }} else {{")?;
5058    writeln!(code, "        model.forward(prompt_tokens[0])")?;
5059    writeln!(code, "    }};")?;
5060    writeln!(
5061        code,
5062        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
5063    )?;
5064    writeln!(code, "    let prefill_tokens = prompt_tokens.len();")?;
5065    writeln!(
5066        code,
5067        "    eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
5068    )?;
5069    writeln!(
5070        code,
5071        "        prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
5072    )?;
5073    writeln!(code)?;
5074    writeln!(code, "    // Generate tokens")?;
5075    writeln!(code, "    let mut next_token = argmax(&logits);")?;
5076    writeln!(code, "    let gen_start = Instant::now();")?;
5077    writeln!(code, "    let mut generated_count: usize = 0;")?;
5078    writeln!(code)?;
5079    writeln!(code, "    for _ in 0..max_tokens {{")?;
5080    writeln!(
5081        code,
5082        "        if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
5083    )?;
5084    writeln!(code, "            if !quiet {{")?;
5085    writeln!(code, "                print!(\"{{}}\", text);")?;
5086    writeln!(code, "                std::io::stdout().flush().ok();")?;
5087    writeln!(code, "            }}")?;
5088    writeln!(code, "        }}")?;
5089    writeln!(code, "        generated_count += 1;")?;
5090    writeln!(code)?;
5091    writeln!(
5092        code,
5093        "        // Use profiling forward for first token when --profile is set"
5094    )?;
5095    writeln!(
5096        code,
5097        "        let logits = if profile && generated_count == 1 {{"
5098    )?;
5099    writeln!(code, "            model.forward_profile(next_token)")?;
5100    writeln!(code, "        }} else {{")?;
5101    writeln!(code, "            model.forward(next_token)")?;
5102    writeln!(code, "        }};")?;
5103    writeln!(code, "        next_token = argmax(&logits);")?;
5104    writeln!(code)?;
5105    writeln!(code, "        // Stop on EOS (token 2 for most models)")?;
5106    writeln!(code, "        if next_token == 2 {{")?;
5107    writeln!(code, "            break;")?;
5108    writeln!(code, "        }}")?;
5109    writeln!(code)?;
5110    writeln!(
5111        code,
5112        "        // Yield between tokens to reduce sustained GPU thermal load."
5113    )?;
5114    writeln!(
5115        code,
5116        "        // On Apple Silicon, continuous GPU saturation causes thermal throttling"
5117    )?;
5118    writeln!(
5119        code,
5120        "        // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
5121    )?;
5122    writeln!(
5123        code,
5124        "        // briefly, providing a micro-break that helps sustain peak throughput."
5125    )?;
5126    writeln!(code, "        std::thread::yield_now();")?;
5127    writeln!(code, "    }}")?;
5128    writeln!(code, "    if !quiet {{")?;
5129    writeln!(code, "        println!();")?;
5130    writeln!(code, "    }}")?;
5131    writeln!(
5132        code,
5133        "    let gen_elapsed = gen_start.elapsed().as_secs_f64();"
5134    )?;
5135    writeln!(
5136        code,
5137        "    eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
5138    )?;
5139    writeln!(
5140        code,
5141        "        generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
5142    )?;
5143    writeln!(code, "}}")?;
5144    writeln!(code)?;
5145
5146    // -- argmax helper --
5147    writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
5148    writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
5149    writeln!(code, "    logits.iter()")?;
5150    writeln!(code, "        .enumerate()")?;
5151    writeln!(
5152        code,
5153        "        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
5154    )?;
5155    writeln!(code, "        .map(|(i, _)| i as u32)")?;
5156    writeln!(code, "        .unwrap_or(0)")?;
5157    writeln!(code, "}}")?;
5158    writeln!(code)?;
5159
5160    // -- Request/Response types for OpenAI API --
5161    writeln!(
5162        code,
5163        "// -----------------------------------------------------------------------"
5164    )?;
5165    writeln!(code, "// OpenAI-compatible API server")?;
5166    writeln!(
5167        code,
5168        "// -----------------------------------------------------------------------"
5169    )?;
5170    writeln!(code)?;
5171    writeln!(code, "#[derive(Deserialize)]")?;
5172    writeln!(code, "struct ChatRequest {{")?;
5173    writeln!(code, "    messages: Vec<ChatMessage>,")?;
5174    writeln!(code, "    #[serde(default)]")?;
5175    writeln!(code, "    stream: Option<bool>,")?;
5176    writeln!(code, "    #[serde(default)]")?;
5177    writeln!(code, "    max_tokens: Option<usize>,")?;
5178    writeln!(code, "    #[serde(default)]")?;
5179    writeln!(code, "    temperature: Option<f32>,")?;
5180    writeln!(code, "    #[serde(default)]")?;
5181    writeln!(code, "    model: Option<String>,")?;
5182    writeln!(code, "}}")?;
5183    writeln!(code)?;
5184    writeln!(code, "#[derive(Deserialize)]")?;
5185    writeln!(code, "struct ChatMessage {{")?;
5186    writeln!(code, "    role: String,")?;
5187    writeln!(code, "    content: String,")?;
5188    writeln!(code, "}}")?;
5189    writeln!(code)?;
5190
5191    // -- format_chat_messages --
5192    writeln!(
5193        code,
5194        "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
5195    )?;
5196    writeln!(code, "    let mut prompt = String::new();")?;
5197    writeln!(code, "    for msg in messages {{")?;
5198    writeln!(code, "        prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
5199    writeln!(code, "    }}")?;
5200    writeln!(code, "    prompt.push_str(\"<|im_start|>assistant\\n\");")?;
5201    writeln!(code, "    prompt")?;
5202    writeln!(code, "}}")?;
5203    writeln!(code)?;
5204
5205    // -- prefill helper --
5206    writeln!(
5207        code,
5208        "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
5209    )?;
5210    writeln!(code, "    let len = tokens.len();")?;
5211    writeln!(code, "    if len > 1 {{")?;
5212    writeln!(
5213        code,
5214        "        model.forward_prefill_batch(&tokens[..len - 1]);"
5215    )?;
5216    writeln!(code, "    }}")?;
5217    writeln!(code, "    model.forward(tokens[len - 1])")?;
5218    writeln!(code, "}}")?;
5219    writeln!(code)?;
5220
5221    // -- serve function --
5222    writeln!(
5223        code,
5224        "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
5225    )?;
5226    writeln!(code, "    let addr = format!(\"0.0.0.0:{{}}\", port);")?;
5227    writeln!(code, "    let server = tiny_http::Server::http(&addr)")?;
5228    writeln!(code, "        .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
5229    writeln!(
5230        code,
5231        "    eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
5232    )?;
5233    writeln!(code, "    eprintln!(\"Endpoints:\");")?;
5234    writeln!(code, "    eprintln!(\"  POST /v1/chat/completions\");")?;
5235    writeln!(code, "    eprintln!(\"  GET  /v1/models\");")?;
5236    writeln!(code, "    eprintln!(\"  GET  /health\");")?;
5237    writeln!(code)?;
5238    writeln!(code, "    for request in server.incoming_requests() {{")?;
5239    writeln!(code, "        let method = request.method().to_string();")?;
5240    writeln!(code, "        let url = request.url().to_string();")?;
5241    writeln!(code)?;
5242    writeln!(code, "        match (method.as_str(), url.as_str()) {{")?;
5243
5244    // -- POST /v1/chat/completions --
5245    writeln!(
5246        code,
5247        "            (\"POST\", \"/v1/chat/completions\") => {{"
5248    )?;
5249    writeln!(
5250        code,
5251        "                handle_chat_completion(&mut model, &tokenizer, request);"
5252    )?;
5253    writeln!(code, "            }}")?;
5254
5255    // -- GET /v1/models --
5256    writeln!(code, "            (\"GET\", \"/v1/models\") => {{")?;
5257    writeln!(code, "                let body = serde_json::json!({{")?;
5258    writeln!(code, "                    \"object\": \"list\",")?;
5259    writeln!(code, "                    \"data\": [{{")?;
5260    writeln!(code, "                        \"id\": \"forgellm-metal\",")?;
5261    writeln!(code, "                        \"object\": \"model\",")?;
5262    writeln!(code, "                        \"owned_by\": \"forgellm\"")?;
5263    writeln!(code, "                    }}]")?;
5264    writeln!(code, "                }});")?;
5265    writeln!(
5266        code,
5267        "                let resp = tiny_http::Response::from_string(body.to_string())"
5268    )?;
5269    writeln!(code, "                    .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
5270    writeln!(code, "                request.respond(resp).ok();")?;
5271    writeln!(code, "            }}")?;
5272
5273    // -- GET /health --
5274    writeln!(code, "            (\"GET\", \"/health\") => {{")?;
5275    writeln!(code, "                let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
5276    writeln!(code, "                request.respond(resp).ok();")?;
5277    writeln!(code, "            }}")?;
5278
5279    // -- 404 --
5280    writeln!(code, "            _ => {{")?;
5281    writeln!(
5282        code,
5283        "                let resp = tiny_http::Response::from_string(\"Not Found\")"
5284    )?;
5285    writeln!(code, "                    .with_status_code(404);")?;
5286    writeln!(code, "                request.respond(resp).ok();")?;
5287    writeln!(code, "            }}")?;
5288    writeln!(code, "        }}")?;
5289    writeln!(code, "    }}")?;
5290    writeln!(code, "}}")?;
5291    writeln!(code)?;
5292
5293    // -- handle_chat_completion --
5294    writeln!(code, "fn handle_chat_completion(")?;
5295    writeln!(code, "    model: &mut model::MetalModel,")?;
5296    writeln!(code, "    tokenizer: &tokenizers::Tokenizer,")?;
5297    writeln!(code, "    mut request: tiny_http::Request,")?;
5298    writeln!(code, ") {{")?;
5299    writeln!(code, "    // Read request body")?;
5300    writeln!(code, "    let mut body = String::new();")?;
5301    writeln!(
5302        code,
5303        "    if request.as_reader().read_to_string(&mut body).is_err() {{"
5304    )?;
5305    writeln!(code, "        let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
5306    writeln!(code, "            .with_status_code(400);")?;
5307    writeln!(code, "        request.respond(resp).ok();")?;
5308    writeln!(code, "        return;")?;
5309    writeln!(code, "    }}")?;
5310    writeln!(code)?;
5311    writeln!(code, "    // Parse JSON")?;
5312    writeln!(
5313        code,
5314        "    let req: ChatRequest = match serde_json::from_str(&body) {{"
5315    )?;
5316    writeln!(code, "        Ok(r) => r,")?;
5317    writeln!(code, "        Err(e) => {{")?;
5318    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
5319    writeln!(code, "                .with_status_code(400);")?;
5320    writeln!(code, "            request.respond(resp).ok();")?;
5321    writeln!(code, "            return;")?;
5322    writeln!(code, "        }}")?;
5323    writeln!(code, "    }};")?;
5324    writeln!(code)?;
5325    writeln!(
5326        code,
5327        "    let prompt = format_chat_messages(&req.messages);"
5328    )?;
5329    writeln!(
5330        code,
5331        "    let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
5332    )?;
5333    writeln!(code, "        Ok(e) => e,")?;
5334    writeln!(code, "        Err(e) => {{")?;
5335    writeln!(code, "            let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
5336    writeln!(code, "                .with_status_code(500);")?;
5337    writeln!(code, "            request.respond(resp).ok();")?;
5338    writeln!(code, "            return;")?;
5339    writeln!(code, "        }}")?;
5340    writeln!(code, "    }};")?;
5341    writeln!(code, "    let prompt_tokens = encoding.get_ids();")?;
5342    writeln!(code, "    let stream = req.stream.unwrap_or(false);")?;
5343    writeln!(code, "    let max_tokens = req.max_tokens.unwrap_or(256);")?;
5344    writeln!(
5345        code,
5346        "    let _temperature = req.temperature.unwrap_or(1.0);"
5347    )?;
5348    writeln!(code)?;
5349
5350    // -- Reset KV cache for each request --
5351    writeln!(code, "    model.reset();")?;
5352    writeln!(code)?;
5353
5354    // -- Prefill with timing --
5355    writeln!(code, "    let prefill_start = Instant::now();")?;
5356    writeln!(code, "    let logits = prefill(model, prompt_tokens);")?;
5357    writeln!(
5358        code,
5359        "    let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
5360    )?;
5361    writeln!(code, "    let prefill_count = prompt_tokens.len();")?;
5362    writeln!(code, "    let mut next_token = argmax(&logits);")?;
5363    writeln!(code)?;
5364
5365    writeln!(code, "    if stream {{")?;
5366
5367    // -- SSE streaming response --
5368    writeln!(
5369        code,
5370        "        // SSE streaming: generate tokens and build SSE body"
5371    )?;
5372    writeln!(code, "        let gen_start = Instant::now();")?;
5373    writeln!(code, "        let mut generated_count: usize = 0;")?;
5374    writeln!(code, "        let mut sse_body = String::new();")?;
5375    writeln!(code, "        for _ in 0..max_tokens {{")?;
5376    writeln!(code, "            if next_token == 2 {{ break; }}")?;
5377    writeln!(
5378        code,
5379        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
5380    )?;
5381    writeln!(
5382        code,
5383        "                let escaped = serde_json::to_string(&text).unwrap_or_default();"
5384    )?;
5385    writeln!(
5386        code,
5387        "                // escaped includes surrounding quotes, strip them"
5388    )?;
5389    writeln!(
5390        code,
5391        "                let inner = &escaped[1..escaped.len()-1];"
5392    )?;
5393    writeln!(code, "                sse_body.push_str(&format!(")?;
5394    writeln!(code, "                    \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
5395    writeln!(code, "                    inner")?;
5396    writeln!(code, "                ));")?;
5397    writeln!(code, "            }}")?;
5398    writeln!(code, "            generated_count += 1;")?;
5399    writeln!(code, "            let logits = model.forward(next_token);")?;
5400    writeln!(code, "            next_token = argmax(&logits);")?;
5401    writeln!(code, "        }}")?;
5402    writeln!(
5403        code,
5404        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
5405    )?;
5406    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
5407    writeln!(code, "        let gen_time_ms = gen_elapsed * 1000.0;")?;
5408    writeln!(code)?;
5409    writeln!(
5410        code,
5411        "        // Final chunk with finish_reason, timing, and DONE sentinel"
5412    )?;
5413    writeln!(code, "        sse_body.push_str(&format!(")?;
5414    writeln!(code, "            \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
5415    writeln!(code, "            prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
5416    writeln!(code, "        ));")?;
5417    writeln!(code)?;
5418    writeln!(
5419        code,
5420        "        let resp = tiny_http::Response::from_string(sse_body)"
5421    )?;
5422    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
5423    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
5424    writeln!(code, "        request.respond(resp).ok();")?;
5425
5426    writeln!(code, "    }} else {{")?;
5427
5428    // -- Non-streaming response --
5429    writeln!(
5430        code,
5431        "        // Non-streaming: generate all tokens, return JSON"
5432    )?;
5433    writeln!(code, "        let gen_start = Instant::now();")?;
5434    writeln!(code, "        let mut generated_count: usize = 0;")?;
5435    writeln!(code, "        let mut generated = String::new();")?;
5436    writeln!(code, "        for _ in 0..max_tokens {{")?;
5437    writeln!(code, "            if next_token == 2 {{ break; }}")?;
5438    writeln!(
5439        code,
5440        "            if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
5441    )?;
5442    writeln!(code, "                generated.push_str(&text);")?;
5443    writeln!(code, "            }}")?;
5444    writeln!(code, "            generated_count += 1;")?;
5445    writeln!(code, "            let logits = model.forward(next_token);")?;
5446    writeln!(code, "            next_token = argmax(&logits);")?;
5447    writeln!(code, "        }}")?;
5448    writeln!(
5449        code,
5450        "        let gen_elapsed = gen_start.elapsed().as_secs_f64();"
5451    )?;
5452    writeln!(code, "        let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
5453    writeln!(code)?;
5454    writeln!(code, "        let resp_json = serde_json::json!({{")?;
5455    writeln!(code, "            \"id\": \"chatcmpl-1\",")?;
5456    writeln!(code, "            \"object\": \"chat.completion\",")?;
5457    writeln!(code, "            \"choices\": [{{")?;
5458    writeln!(code, "                \"index\": 0,")?;
5459    writeln!(code, "                \"message\": {{")?;
5460    writeln!(code, "                    \"role\": \"assistant\",")?;
5461    writeln!(code, "                    \"content\": generated")?;
5462    writeln!(code, "                }},")?;
5463    writeln!(code, "                \"finish_reason\": \"stop\"")?;
5464    writeln!(code, "            }}],")?;
5465    writeln!(code, "            \"usage\": {{")?;
5466    writeln!(code, "                \"prefill_tokens\": prefill_count,")?;
5467    writeln!(
5468        code,
5469        "                \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
5470    )?;
5471    writeln!(
5472        code,
5473        "                \"generation_tokens\": generated_count,"
5474    )?;
5475    writeln!(
5476        code,
5477        "                \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
5478    )?;
5479    writeln!(code, "                \"tokens_per_sec\": gen_tok_s")?;
5480    writeln!(code, "            }}")?;
5481    writeln!(code, "        }});")?;
5482    writeln!(
5483        code,
5484        "        let resp = tiny_http::Response::from_string(resp_json.to_string())"
5485    )?;
5486    writeln!(code, "            .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
5487    writeln!(code, "        request.respond(resp).ok();")?;
5488    writeln!(code, "    }}")?;
5489    writeln!(code, "}}")?;
5490
5491    Ok(code)
5492}
5493
5494// ---------------------------------------------------------------------------
5495// Tests
5496// ---------------------------------------------------------------------------
5497
5498#[cfg(test)]
5499mod tests {
5500    use super::*;
5501    use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
5502
5503    fn minimal_config() -> ModelConfig {
5504        ModelConfig {
5505            architecture: Architecture::Llama,
5506            hidden_size: 64,
5507            intermediate_size: 128,
5508            num_layers: 2,
5509            num_attention_heads: 4,
5510            num_kv_heads: 4,
5511            head_dim: 16,
5512            vocab_size: 256,
5513            max_seq_len: 512,
5514            rms_norm_eps: 1e-5,
5515            rope_theta: 10000.0,
5516            dtype: DType::F32,
5517            sliding_window_size: None,
5518            qkv_bias: false,
5519        }
5520    }
5521
5522    fn minimal_graph() -> Graph {
5523        Graph::new("test-metal").with_config(minimal_config())
5524    }
5525
5526    #[test]
5527    fn generate_metal_project_creates_files() {
5528        let dir = tempfile::tempdir().unwrap();
5529        let graph = minimal_graph();
5530        generate_metal_project(&graph, dir.path(), "test-model").unwrap();
5531
5532        assert!(
5533            dir.path().join("Cargo.toml").exists(),
5534            "Cargo.toml should be created"
5535        );
5536        assert!(
5537            dir.path().join("src/model.rs").exists(),
5538            "src/model.rs should be created"
5539        );
5540        assert!(
5541            dir.path().join("src/main.rs").exists(),
5542            "src/main.rs should be created"
5543        );
5544        assert!(
5545            dir.path().join("shaders/kernels.metal").exists(),
5546            "shaders/kernels.metal should be created"
5547        );
5548    }
5549
5550    #[test]
5551    fn generated_cargo_toml_has_metal_dep() {
5552        let toml = generate_cargo_toml("my-model");
5553        assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
5554        assert!(
5555            toml.contains("tokenizers"),
5556            "Cargo.toml should depend on tokenizers"
5557        );
5558        assert!(
5559            toml.contains("memmap2"),
5560            "Cargo.toml should depend on memmap2"
5561        );
5562        assert!(toml.contains("half"), "Cargo.toml should depend on half");
5563    }
5564
5565    #[test]
5566    fn generated_model_rs_contains_metal_code() {
5567        let config = minimal_config();
5568        let model_rs = generate_model_rs(&config).unwrap();
5569
5570        assert!(
5571            model_rs.contains("pub struct MetalModel"),
5572            "model.rs should define MetalModel struct"
5573        );
5574        assert!(
5575            model_rs.contains("matmul_pipeline: ComputePipelineState"),
5576            "MetalModel should have matmul_pipeline field"
5577        );
5578        assert!(
5579            model_rs.contains("Device::system_default()"),
5580            "model.rs should use Metal device"
5581        );
5582        assert!(
5583            model_rs.contains("new_library_with_source"),
5584            "model.rs should compile Metal shaders"
5585        );
5586        assert!(
5587            model_rs.contains("fn new(weights: &[u8])"),
5588            "MetalModel should implement new()"
5589        );
5590        assert!(
5591            model_rs.contains("fn forward(&mut self, token_id: u32)"),
5592            "MetalModel should implement forward()"
5593        );
5594    }
5595
5596    #[test]
5597    fn generated_shaders_contain_kernel_names() {
5598        let shaders = generate_metal_shaders(&minimal_config());
5599
5600        assert!(
5601            shaders.contains("kernel void matmul_vec"),
5602            "shaders should contain matmul_vec kernel"
5603        );
5604        assert!(
5605            shaders.contains("kernel void rms_norm"),
5606            "shaders should contain rms_norm kernel"
5607        );
5608        assert!(
5609            shaders.contains("kernel void rope"),
5610            "shaders should contain rope kernel"
5611        );
5612        assert!(
5613            shaders.contains("kernel void softmax"),
5614            "shaders should contain softmax kernel"
5615        );
5616        assert!(
5617            shaders.contains("kernel void silu_mul("),
5618            "shaders should contain silu_mul kernel"
5619        );
5620        assert!(
5621            shaders.contains("kernel void silu_mul_fused"),
5622            "shaders should contain silu_mul_fused kernel"
5623        );
5624        assert!(
5625            shaders.contains("kernel void elementwise_add"),
5626            "shaders should contain elementwise_add kernel"
5627        );
5628        assert!(
5629            shaders.contains("kernel void attention"),
5630            "shaders should contain attention kernel"
5631        );
5632        assert!(
5633            shaders.contains("kernel void add_inplace"),
5634            "shaders should contain add_inplace kernel"
5635        );
5636        assert!(
5637            shaders.contains("kernel void copy_buffer"),
5638            "shaders should contain copy_buffer kernel"
5639        );
5640        assert!(
5641            shaders.contains("kernel void copy_offset"),
5642            "shaders should contain copy_offset kernel"
5643        );
5644    }
5645
5646    #[test]
5647    fn generated_shaders_use_simdgroup_features() {
5648        let shaders = generate_metal_shaders(&minimal_config());
5649
5650        assert!(
5651            shaders.contains("threadgroup_barrier"),
5652            "shaders should use threadgroup barriers"
5653        );
5654        assert!(
5655            shaders.contains("threadgroup float"),
5656            "shaders should use threadgroup shared memory"
5657        );
5658        assert!(
5659            shaders.contains("thread_index_in_threadgroup"),
5660            "shaders should use threadgroup indexing"
5661        );
5662        assert!(
5663            shaders.contains("simd_sum"),
5664            "shaders should use simd_sum for warp-level reduction"
5665        );
5666        assert!(
5667            shaders.contains("simd_max"),
5668            "attention kernel should use simd_max for cooperative softmax"
5669        );
5670        assert!(
5671            shaders.contains("thread_index_in_simdgroup"),
5672            "shaders should use simdgroup lane indexing"
5673        );
5674        assert!(
5675            shaders.contains("simdgroup_index_in_threadgroup"),
5676            "shaders should use simdgroup indexing within threadgroup"
5677        );
5678        assert!(
5679            shaders.contains("float4"),
5680            "matmul_vec should use float4 vectorized loads"
5681        );
5682    }
5683
5684    #[test]
5685    fn generated_main_rs_has_tokenizer_usage() {
5686        let config = minimal_config();
5687        let main_rs = generate_main_rs("test-model", &config).unwrap();
5688
5689        assert!(
5690            main_rs.contains("tokenizers::Tokenizer"),
5691            "main.rs should use tokenizers crate"
5692        );
5693        assert!(
5694            main_rs.contains("MetalModel::new"),
5695            "main.rs should call MetalModel::new"
5696        );
5697        assert!(
5698            main_rs.contains("model.forward"),
5699            "main.rs should call model.forward"
5700        );
5701        assert!(
5702            main_rs.contains("memmap2"),
5703            "main.rs should use memmap2 for zero-copy weight loading"
5704        );
5705    }
5706
5707    #[test]
5708    fn missing_config_returns_error() {
5709        let dir = tempfile::tempdir().unwrap();
5710        let graph = Graph::new("no-config");
5711        let result = generate_metal_project(&graph, dir.path(), "fail");
5712        assert!(
5713            matches!(result, Err(MetalCodegenError::MissingConfig)),
5714            "should fail with MissingConfig when graph has no config"
5715        );
5716    }
5717
5718    #[test]
5719    fn sanitize_name_works() {
5720        assert_eq!(sanitize_name("My Model!"), "my-model");
5721        assert_eq!(sanitize_name("test_model"), "test-model");
5722        assert_eq!(sanitize_name("simple"), "simple");
5723    }
5724
5725    #[test]
5726    fn generated_forward_uses_single_command_buffer() {
5727        let config = minimal_config();
5728        let model_rs = generate_model_rs(&config).unwrap();
5729
5730        // The forward function should create exactly one command buffer.
5731        // Use the exact signature to avoid matching forward_prefill/forward_profile.
5732        let forward_start = model_rs
5733            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
5734            .unwrap();
5735        let forward_body = &model_rs[forward_start..];
5736        // End at the next pub/private method
5737        let forward_end = forward_body
5738            .find("\n    pub fn forward_profile")
5739            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
5740            .or_else(|| forward_body.find("\n    fn dispatch_"))
5741            .unwrap_or(forward_body.len());
5742        let forward_code = &forward_body[..forward_end];
5743
5744        // Should have exactly one new_command_buffer call
5745        let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
5746        assert_eq!(
5747            cmd_buf_count, 1,
5748            "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
5749        );
5750
5751        // Should have exactly one commit call
5752        let commit_count = forward_code.matches("cmd.commit()").count();
5753        assert_eq!(
5754            commit_count, 1,
5755            "forward() should commit exactly once, found {commit_count}"
5756        );
5757
5758        // Should wait: once for cmd + possibly once for prev_cmd drain
5759        let wait_count = forward_code.matches("wait_until_completed()").count();
5760        assert!(
5761            wait_count >= 1 && wait_count <= 2,
5762            "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
5763        );
5764    }
5765
5766    #[test]
5767    fn generated_model_has_preallocated_working_buffers() {
5768        let config = minimal_config();
5769        let model_rs = generate_model_rs(&config).unwrap();
5770
5771        for buf_name in &[
5772            "normed_buf",
5773            "qkv_buf",
5774            "attn_out_buf",
5775            "attn_proj_buf",
5776            "gate_up_buf",
5777            "ffn_hidden_buf",
5778            "ffn_out_buf",
5779            "add_tmp_buf",
5780        ] {
5781            assert!(
5782                model_rs.contains(&format!("{buf_name}: Buffer")),
5783                "MetalModel should have pre-allocated {buf_name} field"
5784            );
5785        }
5786    }
5787
5788    #[test]
5789    fn generated_dispatch_helpers_take_compute_encoder_ref() {
5790        let config = minimal_config();
5791        let model_rs = generate_model_rs(&config).unwrap();
5792
5793        for method in &[
5794            "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
5795            "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
5796            "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
5797            "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
5798            "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
5799            "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
5800            "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
5801            "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
5802            "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
5803            "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
5804            "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
5805            "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
5806            "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
5807        ] {
5808            assert!(
5809                model_rs.contains(method),
5810                "model.rs should contain dispatch helper: {method}"
5811            );
5812        }
5813    }
5814
5815    #[test]
5816    fn generated_helpers_do_not_create_command_buffers_or_encoders() {
5817        let config = minimal_config();
5818        let model_rs = generate_model_rs(&config).unwrap();
5819
5820        // Find dispatch helpers section and check none create their own encoders
5821        let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
5822        let helpers_code = &model_rs[helpers_start..];
5823
5824        // None of the dispatch_ helpers should call new_command_buffer
5825        assert!(
5826            !helpers_code.contains("self.queue.new_command_buffer()"),
5827            "dispatch helpers should not create their own command buffers"
5828        );
5829
5830        // None should create their own compute encoders
5831        assert!(
5832            !helpers_code.contains("new_compute_command_encoder()"),
5833            "dispatch helpers should not create their own compute encoders"
5834        );
5835
5836        // None should call end_encoding
5837        assert!(
5838            !helpers_code.contains("end_encoding()"),
5839            "dispatch helpers should not call end_encoding"
5840        );
5841
5842        // None should call commit or wait
5843        assert!(
5844            !helpers_code.contains(".commit()"),
5845            "dispatch helpers should not commit command buffers"
5846        );
5847        assert!(
5848            !helpers_code.contains("wait_until_completed"),
5849            "dispatch helpers should not wait on command buffers"
5850        );
5851    }
5852
5853    #[test]
5854    fn generated_forward_batches_compute_encoders() {
5855        let config = minimal_config();
5856        let model_rs = generate_model_rs(&config).unwrap();
5857
5858        // Find the forward function body (exact signature to avoid matching forward_prefill/forward_profile)
5859        let forward_start = model_rs
5860            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
5861            .unwrap();
5862        let forward_body = &model_rs[forward_start..];
5863        let forward_end = forward_body
5864            .find("\n    pub fn forward_profile")
5865            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
5866            .or_else(|| forward_body.find("\n    fn dispatch_"))
5867            .unwrap_or(forward_body.len());
5868        let forward_code = &forward_body[..forward_end];
5869
5870        // Forward should not allocate new buffers
5871        assert!(
5872            !forward_code.contains("device.new_buffer"),
5873            "forward() should not allocate new buffers per call"
5874        );
5875
5876        // Forward should use a SINGLE compute encoder for the entire pass (no blit transitions).
5877        // Copy operations use compute copy kernels instead of blit encoders.
5878        let compute_encoder_count = forward_code
5879            .matches("new_compute_command_encoder()")
5880            .count();
5881        let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
5882
5883        // Single compute encoder for everything: embedding copy, all layers, final norm + logits
5884        assert_eq!(
5885            compute_encoder_count, 1,
5886            "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
5887        );
5888        assert_eq!(
5889            blit_encoder_count, 0,
5890            "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
5891        );
5892    }
5893
5894    #[test]
5895    fn generated_forward_uses_add_inplace() {
5896        let config = minimal_config();
5897        let model_rs = generate_model_rs(&config).unwrap();
5898
5899        // Should use in-place add (no blit copy-back needed)
5900        assert!(
5901            model_rs.contains("dispatch_add_inplace"),
5902            "forward() should use dispatch_add_inplace for residual connections"
5903        );
5904        assert!(
5905            model_rs.contains("add_inplace_pipeline"),
5906            "MetalModel should have add_inplace_pipeline"
5907        );
5908    }
5909
5910    fn minimal_q8_config() -> ModelConfig {
5911        ModelConfig {
5912            architecture: Architecture::Llama,
5913            hidden_size: 64,
5914            intermediate_size: 128,
5915            num_layers: 2,
5916            num_attention_heads: 4,
5917            num_kv_heads: 4,
5918            head_dim: 16,
5919            vocab_size: 256,
5920            max_seq_len: 512,
5921            rms_norm_eps: 1e-5,
5922            rope_theta: 10000.0,
5923            dtype: DType::Q8_0,
5924            sliding_window_size: None,
5925            qkv_bias: false,
5926        }
5927    }
5928
5929    #[test]
5930    fn generated_shaders_contain_q8_kernel() {
5931        let shaders = generate_metal_shaders(&minimal_config());
5932
5933        assert!(
5934            shaders.contains("kernel void matmul_vec_q8"),
5935            "shaders should contain matmul_vec_q8 kernel"
5936        );
5937        assert!(
5938            shaders.contains("device const uchar* matrix"),
5939            "matmul_vec_q8 should accept raw Q8_0 bytes"
5940        );
5941        assert!(
5942            shaders.contains("packed_short4"),
5943            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
5944        );
5945        assert!(
5946            shaders.contains("as_type<char2>"),
5947            "matmul_vec_q8 should bitcast short lanes to char2"
5948        );
5949        assert!(
5950            shaders.contains("device const half*"),
5951            "matmul_vec_q8 should read f16 scale via half pointer"
5952        );
5953    }
5954
5955    #[test]
5956    fn generated_model_uses_fused_qkv_projections() {
5957        let config = minimal_config();
5958        let model_rs = generate_model_rs(&config).unwrap();
5959
5960        // Should have fused QKV weight in layer buffers
5961        assert!(
5962            model_rs.contains("qkv_weight: Buffer"),
5963            "LayerBuffers should have fused qkv_weight field"
5964        );
5965        // Should NOT have separate Q/K/V weight fields (check with leading whitespace to avoid substring matches)
5966        assert!(
5967            !model_rs.contains("    q_weight: Buffer"),
5968            "LayerBuffers should not have separate q_weight field"
5969        );
5970        assert!(
5971            !model_rs.contains("    k_weight: Buffer"),
5972            "LayerBuffers should not have separate k_weight field"
5973        );
5974        assert!(
5975            !model_rs.contains("    v_weight: Buffer"),
5976            "LayerBuffers should not have separate v_weight field"
5977        );
5978
5979        // Should have fused gate_up_weight
5980        assert!(
5981            model_rs.contains("gate_up_weight: Buffer"),
5982            "LayerBuffers should have fused gate_up_weight field"
5983        );
5984        // Should NOT have separate gate/up weight fields
5985        assert!(
5986            !model_rs.contains("    gate_weight: Buffer"),
5987            "LayerBuffers should not have separate gate_weight field"
5988        );
5989        assert!(
5990            !model_rs.contains("    up_weight: Buffer"),
5991            "LayerBuffers should not have separate up_weight field"
5992        );
5993
5994        // Should have fused working buffers
5995        assert!(
5996            model_rs.contains("qkv_buf: Buffer"),
5997            "MetalModel should have fused qkv_buf"
5998        );
5999        assert!(
6000            model_rs.contains("gate_up_buf: Buffer"),
6001            "MetalModel should have fused gate_up_buf"
6002        );
6003
6004        // Forward pass should use fused dispatch
6005        assert!(
6006            model_rs.contains("dispatch_silu_mul_fused"),
6007            "forward pass should use dispatch_silu_mul_fused"
6008        );
6009        assert!(
6010            model_rs.contains("dispatch_rope_offset"),
6011            "forward pass should use dispatch_rope_offset for fused QKV"
6012        );
6013        assert!(
6014            model_rs.contains("dispatch_attention_offset"),
6015            "forward pass should use dispatch_attention_offset for fused QKV"
6016        );
6017    }
6018
6019    #[test]
6020    fn q8_model_has_matmul_q8_pipeline() {
6021        let config = minimal_q8_config();
6022        let model_rs = generate_model_rs(&config).unwrap();
6023
6024        assert!(
6025            model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
6026            "MetalModel should have matmul_q8_pipeline field"
6027        );
6028        assert!(
6029            model_rs.contains("matmul_q8_pipeline,"),
6030            "MetalModel Self should include matmul_q8_pipeline"
6031        );
6032    }
6033
6034    #[test]
6035    fn q8_model_uses_dispatch_matmul_q8() {
6036        let config = minimal_q8_config();
6037        let model_rs = generate_model_rs(&config).unwrap();
6038
6039        assert!(
6040            model_rs.contains("dispatch_matmul_q8"),
6041            "Q8_0 model should use dispatch_matmul_q8 for projections"
6042        );
6043        assert!(
6044            model_rs.contains("fn dispatch_matmul_q8"),
6045            "model.rs should define dispatch_matmul_q8 method"
6046        );
6047    }
6048
6049    #[test]
6050    fn q8_model_loads_raw_bytes_not_dequantized() {
6051        let config = minimal_q8_config();
6052        let model_rs = generate_model_rs(&config).unwrap();
6053
6054        // Should NOT contain dequantization code
6055        assert!(
6056            !model_rs.contains("f16_to_f32"),
6057            "Q8_0 model should not dequantize weights to f32"
6058        );
6059        assert!(
6060            !model_rs.contains("f32_data"),
6061            "Q8_0 model should not create f32 weight data"
6062        );
6063
6064        // Should load raw Q8_0 bytes directly
6065        assert!(
6066            model_rs.contains("total_raw as u64"),
6067            "Q8_0 model should load raw bytes into Metal buffer"
6068        );
6069    }
6070
6071    #[test]
6072    fn q8_model_norms_stay_f32() {
6073        let config = minimal_q8_config();
6074        let model_rs = generate_model_rs(&config).unwrap();
6075
6076        // Norm weights should still use f32 buffers
6077        assert!(
6078            model_rs.contains("let attn_norm = next_f32_buffer"),
6079            "attn_norm should use f32 buffer even for Q8_0 models"
6080        );
6081        assert!(
6082            model_rs.contains("let ffn_norm = next_f32_buffer"),
6083            "ffn_norm should use f32 buffer even for Q8_0 models"
6084        );
6085        assert!(
6086            model_rs.contains("let norm_buf = next_f32_buffer"),
6087            "final norm should use f32 buffer even for Q8_0 models"
6088        );
6089    }
6090
6091    #[test]
6092    fn q8_model_uses_fused_weight_loading() {
6093        let config = minimal_q8_config();
6094        let model_rs = generate_model_rs(&config).unwrap();
6095
6096        // Should use fused Q8 buffer loading for QKV
6097        assert!(
6098            model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
6099            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
6100        );
6101        // Should use fused Q8 buffer loading for gate+up
6102        assert!(
6103            model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
6104            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
6105        );
6106        // Should still use regular q8 buffer for individual weights
6107        assert!(
6108            model_rs.contains("let o_weight = next_q8_buffer"),
6109            "Q8_0 model should use next_q8_buffer for O weight"
6110        );
6111        assert!(
6112            model_rs.contains("let down_weight = next_q8_buffer"),
6113            "Q8_0 model should use next_q8_buffer for down weight"
6114        );
6115    }
6116
6117    #[test]
6118    fn f32_model_does_not_use_q8_dispatch() {
6119        let config = minimal_config();
6120        let model_rs = generate_model_rs(&config).unwrap();
6121
6122        // f32 model should NOT use Q8 dispatch in forward or forward_prefill
6123        let forward_start = model_rs
6124            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6125            .unwrap();
6126        let forward_body = &model_rs[forward_start..];
6127        let forward_end = forward_body
6128            .find("\n    fn dispatch_")
6129            .unwrap_or(forward_body.len());
6130        let forward_code = &forward_body[..forward_end];
6131
6132        assert!(
6133            !forward_code.contains("dispatch_matmul_q8"),
6134            "f32 model forward should not use dispatch_matmul_q8"
6135        );
6136    }
6137
6138    #[test]
6139    fn q8_dispatch_helper_takes_compute_encoder_ref() {
6140        let config = minimal_q8_config();
6141        let model_rs = generate_model_rs(&config).unwrap();
6142
6143        assert!(
6144            model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
6145            "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
6146        );
6147    }
6148
6149    #[test]
6150    fn generated_model_has_double_buffered_prefill() {
6151        let config = minimal_config();
6152        let model_rs = generate_model_rs(&config).unwrap();
6153
6154        // MetalModel should have prev_cmd field for double-buffered prefill
6155        assert!(
6156            model_rs.contains("prev_cmd: Option<CommandBuffer>"),
6157            "MetalModel should have prev_cmd field for double-buffered prefill"
6158        );
6159
6160        // Should have forward_prefill method
6161        assert!(
6162            model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
6163            "MetalModel should have forward_prefill method"
6164        );
6165
6166        // forward() should drain prev_cmd at the start
6167        assert!(
6168            model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
6169            "forward() should drain prev_cmd from previous prefill"
6170        );
6171    }
6172
6173    #[test]
6174    fn generated_main_rs_uses_forward_prefill_for_prompt() {
6175        let config = minimal_config();
6176        let main_rs = generate_main_rs("test-model", &config).unwrap();
6177
6178        assert!(
6179            main_rs.contains("forward_prefill"),
6180            "main.rs should use forward_prefill for intermediate prompt tokens"
6181        );
6182        assert!(
6183            main_rs.contains("double-buffered"),
6184            "main.rs should document double-buffered prefill"
6185        );
6186    }
6187
6188    #[test]
6189    fn generated_shaders_q8_uses_wide_vectorized_loads() {
6190        let shaders = generate_metal_shaders(&minimal_config());
6191
6192        assert!(
6193            shaders.contains("packed_short4"),
6194            "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
6195        );
6196        assert!(
6197            shaders.contains("d0[0]"),
6198            "matmul_vec_q8 should index the wide pointer for row 0"
6199        );
6200        assert!(
6201            shaders.contains("as_type<char2>"),
6202            "matmul_vec_q8 should bitcast short lanes to char2"
6203        );
6204        assert!(
6205            shaders.contains("dot("),
6206            "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
6207        );
6208    }
6209
6210    // ── Q4_0 tests ──────────────────────────────────────────────────────
6211
6212    fn minimal_q4_config() -> ModelConfig {
6213        ModelConfig {
6214            architecture: Architecture::Llama,
6215            hidden_size: 64,
6216            intermediate_size: 128,
6217            num_layers: 2,
6218            num_attention_heads: 4,
6219            num_kv_heads: 4,
6220            head_dim: 16,
6221            vocab_size: 256,
6222            max_seq_len: 512,
6223            rms_norm_eps: 1e-5,
6224            rope_theta: 10000.0,
6225            dtype: DType::Q4_0,
6226            sliding_window_size: None,
6227            qkv_bias: false,
6228        }
6229    }
6230
6231    #[test]
6232    fn generated_shaders_contain_q4_kernel() {
6233        let shaders = generate_metal_shaders(&minimal_config());
6234
6235        assert!(
6236            shaders.contains("kernel void matmul_vec_q4"),
6237            "shaders should contain matmul_vec_q4 kernel"
6238        );
6239        assert!(
6240            shaders.contains("Q4_ROWS_PER_TG"),
6241            "shaders should define Q4_ROWS_PER_TG constant"
6242        );
6243        assert!(
6244            shaders.contains("Q4_ROWS_PER_SG"),
6245            "shaders should define Q4_ROWS_PER_SG constant"
6246        );
6247    }
6248
6249    #[test]
6250    fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
6251        let shaders = generate_metal_shaders(&minimal_config());
6252
6253        // Q4_0 kernel should use uchar4 for packed byte loads
6254        assert!(
6255            shaders.contains("uchar4"),
6256            "matmul_vec_q4 should use uchar4 for packed byte loads"
6257        );
6258        // Should unpack nibbles with &0xF and >>4
6259        assert!(
6260            shaders.contains("&0xF"),
6261            "matmul_vec_q4 should extract low nibble with &0xF"
6262        );
6263        assert!(
6264            shaders.contains(">>4"),
6265            "matmul_vec_q4 should extract high nibble with >>4"
6266        );
6267        // Should subtract 8 to convert unsigned to signed
6268        assert!(
6269            shaders.contains("-8)"),
6270            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
6271        );
6272        // Should use 18-byte block size
6273        assert!(
6274            shaders.contains("blk * 18"),
6275            "matmul_vec_q4 should use 18-byte block stride"
6276        );
6277    }
6278
6279    #[test]
6280    fn q4_model_has_matmul_q4_pipeline() {
6281        let config = minimal_q4_config();
6282        let model_rs = generate_model_rs(&config).unwrap();
6283
6284        assert!(
6285            model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
6286            "MetalModel should have matmul_q4_pipeline field"
6287        );
6288        assert!(
6289            model_rs.contains("matmul_q4_pipeline,"),
6290            "MetalModel Self should include matmul_q4_pipeline"
6291        );
6292    }
6293
6294    #[test]
6295    fn q4_model_uses_dispatch_matmul_q4() {
6296        let config = minimal_q4_config();
6297        let model_rs = generate_model_rs(&config).unwrap();
6298
6299        assert!(
6300            model_rs.contains("dispatch_matmul_q4"),
6301            "Q4_0 model should use dispatch_matmul_q4 for projections"
6302        );
6303        assert!(
6304            model_rs.contains("fn dispatch_matmul_q4"),
6305            "model.rs should define dispatch_matmul_q4 method"
6306        );
6307    }
6308
6309    #[test]
6310    fn q4_model_loads_raw_bytes_not_dequantized() {
6311        let config = minimal_q4_config();
6312        let model_rs = generate_model_rs(&config).unwrap();
6313
6314        // Should NOT contain dequantization code
6315        assert!(
6316            !model_rs.contains("f16_to_f32"),
6317            "Q4_0 model should not dequantize weights to f32"
6318        );
6319
6320        // Should load raw Q4_0 bytes directly
6321        assert!(
6322            model_rs.contains("total_raw as u64"),
6323            "Q4_0 model should load raw bytes into Metal buffer"
6324        );
6325    }
6326
6327    #[test]
6328    fn q4_model_norms_stay_f32() {
6329        let config = minimal_q4_config();
6330        let model_rs = generate_model_rs(&config).unwrap();
6331
6332        assert!(
6333            model_rs.contains("let attn_norm = next_f32_buffer"),
6334            "attn_norm should use f32 buffer even for Q4_0 models"
6335        );
6336        assert!(
6337            model_rs.contains("let ffn_norm = next_f32_buffer"),
6338            "ffn_norm should use f32 buffer even for Q4_0 models"
6339        );
6340        assert!(
6341            model_rs.contains("let norm_buf = next_f32_buffer"),
6342            "final norm should use f32 buffer even for Q4_0 models"
6343        );
6344    }
6345
6346    #[test]
6347    fn q4_model_uses_fused_weight_loading() {
6348        let config = minimal_q4_config();
6349        let model_rs = generate_model_rs(&config).unwrap();
6350
6351        assert!(
6352            model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
6353            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
6354        );
6355        assert!(
6356            model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
6357            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
6358        );
6359        assert!(
6360            model_rs.contains("let o_weight = next_q4_buffer"),
6361            "Q4_0 model should use next_q4_buffer for O weight"
6362        );
6363        assert!(
6364            model_rs.contains("let down_weight = next_q4_buffer"),
6365            "Q4_0 model should use next_q4_buffer for down weight"
6366        );
6367    }
6368
6369    #[test]
6370    fn q4_dispatch_helper_takes_compute_encoder_ref() {
6371        let config = minimal_q4_config();
6372        let model_rs = generate_model_rs(&config).unwrap();
6373
6374        assert!(
6375            model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
6376            "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
6377        );
6378    }
6379
6380    #[test]
6381    fn f32_model_does_not_use_q4_dispatch() {
6382        let config = minimal_config();
6383        let model_rs = generate_model_rs(&config).unwrap();
6384
6385        let forward_start = model_rs
6386            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6387            .unwrap();
6388        let forward_body = &model_rs[forward_start..];
6389        let forward_end = forward_body
6390            .find("\n    fn dispatch_")
6391            .unwrap_or(forward_body.len());
6392        let forward_code = &forward_body[..forward_end];
6393
6394        assert!(
6395            !forward_code.contains("dispatch_matmul_q4"),
6396            "f32 model forward should not use dispatch_matmul_q4"
6397        );
6398    }
6399
6400    #[test]
6401    fn q4_model_lm_head_uses_q4_buffer() {
6402        let config = minimal_q4_config();
6403        let model_rs = generate_model_rs(&config).unwrap();
6404
6405        assert!(
6406            model_rs.contains("let lm_head_buf = next_q4_buffer"),
6407            "Q4_0 model should use next_q4_buffer for lm_head"
6408        );
6409    }
6410
6411    #[test]
6412    fn vec_tile_size_matches_model_dimensions() {
6413        // Small model: intermediate=128 > hidden=64, so vec_tile should be 128
6414        let small = minimal_config();
6415        let shaders_small = generate_metal_shaders(&small);
6416        assert!(
6417            shaders_small.contains("vec_tile[128]"),
6418            "vec_tile should be sized to max(hidden, intermediate) = 128"
6419        );
6420
6421        // Llama-3.2-1B-like config: intermediate=8192 > hidden=2048
6422        let mut large = minimal_config();
6423        large.hidden_size = 2048;
6424        large.intermediate_size = 8192;
6425        let shaders_large = generate_metal_shaders(&large);
6426        assert!(
6427            shaders_large.contains("vec_tile[8192]"),
6428            "vec_tile should be 8192 for models with intermediate=8192"
6429        );
6430        assert!(
6431            !shaders_large.contains("vec_tile[4096]"),
6432            "vec_tile should NOT be hardcoded to 4096"
6433        );
6434    }
6435
6436    #[test]
6437    fn generated_cargo_toml_has_server_deps() {
6438        let toml = generate_cargo_toml("my-model");
6439        assert!(
6440            toml.contains("tiny_http"),
6441            "Cargo.toml should depend on tiny_http"
6442        );
6443        assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
6444        assert!(
6445            toml.contains("serde_json"),
6446            "Cargo.toml should depend on serde_json"
6447        );
6448    }
6449
6450    #[test]
6451    fn generated_main_rs_has_serve_mode() {
6452        let config = minimal_config();
6453        let main_rs = generate_main_rs("test-model", &config).unwrap();
6454
6455        assert!(
6456            main_rs.contains("--serve"),
6457            "main.rs should parse --serve flag"
6458        );
6459        assert!(
6460            main_rs.contains("--port"),
6461            "main.rs should parse --port flag"
6462        );
6463        assert!(
6464            main_rs.contains("fn serve("),
6465            "main.rs should define serve function"
6466        );
6467        assert!(
6468            main_rs.contains("tiny_http::Server::http"),
6469            "main.rs should create tiny_http server"
6470        );
6471    }
6472
6473    #[test]
6474    fn generated_main_rs_has_chat_completions_endpoint() {
6475        let config = minimal_config();
6476        let main_rs = generate_main_rs("test-model", &config).unwrap();
6477
6478        assert!(
6479            main_rs.contains("/v1/chat/completions"),
6480            "main.rs should handle /v1/chat/completions endpoint"
6481        );
6482        assert!(
6483            main_rs.contains("/v1/models"),
6484            "main.rs should handle /v1/models endpoint"
6485        );
6486        assert!(
6487            main_rs.contains("/health"),
6488            "main.rs should handle /health endpoint"
6489        );
6490    }
6491
6492    #[test]
6493    fn generated_main_rs_has_sse_streaming() {
6494        let config = minimal_config();
6495        let main_rs = generate_main_rs("test-model", &config).unwrap();
6496
6497        assert!(
6498            main_rs.contains("text/event-stream"),
6499            "main.rs should set SSE content type for streaming"
6500        );
6501        assert!(
6502            main_rs.contains("chat.completion.chunk"),
6503            "main.rs should emit SSE chunks"
6504        );
6505        assert!(
6506            main_rs.contains("[DONE]"),
6507            "main.rs should emit [DONE] sentinel"
6508        );
6509    }
6510
6511    #[test]
6512    fn generated_main_rs_has_chat_message_formatting() {
6513        let config = minimal_config();
6514        let main_rs = generate_main_rs("test-model", &config).unwrap();
6515
6516        assert!(
6517            main_rs.contains("fn format_chat_messages"),
6518            "main.rs should define format_chat_messages function"
6519        );
6520        assert!(
6521            main_rs.contains("<|im_start|>"),
6522            "main.rs should use ChatML format"
6523        );
6524        assert!(
6525            main_rs.contains("<|im_end|>"),
6526            "main.rs should use ChatML format"
6527        );
6528    }
6529
6530    #[test]
6531    fn generated_main_rs_has_request_types() {
6532        let config = minimal_config();
6533        let main_rs = generate_main_rs("test-model", &config).unwrap();
6534
6535        assert!(
6536            main_rs.contains("struct ChatRequest"),
6537            "main.rs should define ChatRequest struct"
6538        );
6539        assert!(
6540            main_rs.contains("struct ChatMessage"),
6541            "main.rs should define ChatMessage struct"
6542        );
6543        assert!(
6544            main_rs.contains("Deserialize"),
6545            "main.rs should derive Deserialize for request types"
6546        );
6547    }
6548
6549    #[test]
6550    fn generated_model_has_reset_method() {
6551        let config = minimal_config();
6552        let model_rs = generate_model_rs(&config).unwrap();
6553
6554        assert!(
6555            model_rs.contains("pub fn reset(&mut self)"),
6556            "model.rs should have a reset() method for multi-request serving"
6557        );
6558        assert!(
6559            model_rs.contains("self.pos = 0"),
6560            "reset() should reset position to 0"
6561        );
6562    }
6563
6564    #[test]
6565    fn generated_main_rs_cli_mode_still_works() {
6566        let config = minimal_config();
6567        let main_rs = generate_main_rs("test-model", &config).unwrap();
6568
6569        // CLI mode should still be functional
6570        assert!(
6571            main_rs.contains("fn cli_mode("),
6572            "main.rs should define cli_mode function"
6573        );
6574        assert!(
6575            main_rs.contains("model.forward"),
6576            "main.rs should call model.forward"
6577        );
6578        assert!(
6579            main_rs.contains("model.forward_prefill"),
6580            "main.rs should call model.forward_prefill"
6581        );
6582    }
6583
6584    // ── Batched prefill tests ──────────────────────────────────────────
6585
6586    #[test]
6587    fn generated_shaders_contain_batch_kernels() {
6588        let shaders = generate_metal_shaders(&minimal_config());
6589
6590        assert!(
6591            shaders.contains("kernel void matmul_vec_batch"),
6592            "shaders should contain matmul_vec_batch kernel"
6593        );
6594        assert!(
6595            shaders.contains("kernel void matmul_vec_q8_batch"),
6596            "shaders should contain matmul_vec_q8_batch kernel"
6597        );
6598        assert!(
6599            shaders.contains("kernel void matmul_q8_gemm_batch"),
6600            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
6601        );
6602        assert!(
6603            shaders.contains("kernel void matmul_vec_q4_batch"),
6604            "shaders should contain matmul_vec_q4_batch kernel"
6605        );
6606        assert!(
6607            shaders.contains("kernel void rms_norm_batch"),
6608            "shaders should contain rms_norm_batch kernel"
6609        );
6610        assert!(
6611            shaders.contains("kernel void silu_mul_fused_batch"),
6612            "shaders should contain silu_mul_fused_batch kernel"
6613        );
6614        assert!(
6615            shaders.contains("kernel void add_inplace_batch"),
6616            "shaders should contain add_inplace_batch kernel"
6617        );
6618        assert!(
6619            shaders.contains("kernel void copy_embedding_batch"),
6620            "shaders should contain copy_embedding_batch kernel"
6621        );
6622    }
6623
6624    #[test]
6625    fn generated_model_has_batch_pipelines() {
6626        let config = minimal_config();
6627        let model_rs = generate_model_rs(&config).unwrap();
6628
6629        for pipeline in &[
6630            "matmul_batch_pipeline",
6631            "matmul_q8_batch_pipeline",
6632            "matmul_q8_gemm_batch_pipeline",
6633            "matmul_q4_batch_pipeline",
6634            "rms_norm_batch_pipeline",
6635            "rope_batch_pipeline",
6636            "silu_mul_fused_batch_pipeline",
6637            "add_inplace_batch_pipeline",
6638            "copy_embedding_batch_pipeline",
6639            "attention_batch_pipeline",
6640            "copy_kv_batch_pipeline",
6641        ] {
6642            assert!(
6643                model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
6644                "MetalModel should have {pipeline} field"
6645            );
6646        }
6647    }
6648
6649    #[test]
6650    fn generated_model_has_batch_buffers() {
6651        let config = minimal_config();
6652        let model_rs = generate_model_rs(&config).unwrap();
6653
6654        for buf in &[
6655            "batch_hidden_buf",
6656            "batch_residual_buf",
6657            "batch_qkv_buf",
6658            "batch_attn_out_buf",
6659            "batch_attn_proj_buf",
6660            "batch_gate_up_buf",
6661            "batch_ffn_hidden_buf",
6662            "batch_ffn_out_buf",
6663            "batch_tokens_buf",
6664            "batch_positions_buf",
6665        ] {
6666            assert!(
6667                model_rs.contains(&format!("{buf}: Buffer")),
6668                "MetalModel should have {buf} field"
6669            );
6670        }
6671    }
6672
6673    #[test]
6674    fn generated_model_has_forward_prefill_batch() {
6675        let config = minimal_config();
6676        let model_rs = generate_model_rs(&config).unwrap();
6677
6678        assert!(
6679            model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
6680            "MetalModel should have forward_prefill_batch method"
6681        );
6682
6683        // forward_prefill should delegate to forward_prefill_batch
6684        assert!(
6685            model_rs.contains("self.forward_prefill_batch(&[token_id])"),
6686            "forward_prefill should delegate to forward_prefill_batch"
6687        );
6688    }
6689
6690    #[test]
6691    fn generated_model_has_max_batch_size_constant() {
6692        let config = minimal_config();
6693        let model_rs = generate_model_rs(&config).unwrap();
6694
6695        assert!(
6696            model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
6697            "model.rs should define MAX_BATCH_SIZE constant"
6698        );
6699    }
6700
6701    #[test]
6702    fn forward_prefill_batch_uses_batch_dispatch() {
6703        let config = minimal_config();
6704        let model_rs = generate_model_rs(&config).unwrap();
6705
6706        let batch_start = model_rs
6707            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
6708            .unwrap();
6709        let batch_body = &model_rs[batch_start..];
6710        let batch_end = batch_body
6711            .find("\n    pub fn reset")
6712            .unwrap_or(batch_body.len());
6713        let batch_code = &batch_body[..batch_end];
6714
6715        // Should use batched dispatch methods
6716        assert!(
6717            batch_code.contains("dispatch_rms_norm_batch"),
6718            "forward_prefill_batch should use dispatch_rms_norm_batch"
6719        );
6720        assert!(
6721            batch_code.contains("dispatch_copy_embedding_batch"),
6722            "forward_prefill_batch should use dispatch_copy_embedding_batch"
6723        );
6724        assert!(
6725            batch_code.contains("dispatch_silu_mul_fused_batch"),
6726            "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
6727        );
6728        // Should use batched causal attention dispatch
6729        assert!(
6730            batch_code.contains("dispatch_attention_batch"),
6731            "forward_prefill_batch should use dispatch_attention_batch"
6732        );
6733        // Should use fused KV cache copy (both K and V in one dispatch)
6734        assert!(
6735            batch_code.contains("dispatch_copy_kv_both_batch"),
6736            "forward_prefill_batch should use dispatch_copy_kv_both_batch"
6737        );
6738        // Should use fused RoPE Q+K dispatch
6739        assert!(
6740            batch_code.contains("dispatch_rope_qk_batch"),
6741            "forward_prefill_batch should use dispatch_rope_qk_batch"
6742        );
6743    }
6744
6745    #[test]
6746    fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
6747        let config = minimal_q8_config();
6748        let model_rs = generate_model_rs(&config).unwrap();
6749
6750        let batch_start = model_rs
6751            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
6752            .unwrap();
6753        let batch_body = &model_rs[batch_start..];
6754        let batch_end = batch_body
6755            .find("\n    pub fn reset")
6756            .unwrap_or(batch_body.len());
6757        let batch_code = &batch_body[..batch_end];
6758
6759        assert!(
6760            batch_code.contains("dispatch_matmul_q8_batch"),
6761            "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
6762        );
6763    }
6764
6765    #[test]
6766    fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
6767        let config = minimal_q4_config();
6768        let model_rs = generate_model_rs(&config).unwrap();
6769
6770        let batch_start = model_rs
6771            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
6772            .unwrap();
6773        let batch_body = &model_rs[batch_start..];
6774        let batch_end = batch_body
6775            .find("\n    pub fn reset")
6776            .unwrap_or(batch_body.len());
6777        let batch_code = &batch_body[..batch_end];
6778
6779        assert!(
6780            batch_code.contains("dispatch_matmul_q4_batch"),
6781            "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
6782        );
6783    }
6784
6785    #[test]
6786    fn generated_main_rs_uses_batched_prefill() {
6787        let config = minimal_config();
6788        let main_rs = generate_main_rs("test-model", &config).unwrap();
6789
6790        assert!(
6791            main_rs.contains("forward_prefill_batch"),
6792            "main.rs should use forward_prefill_batch for prompt tokens"
6793        );
6794    }
6795
6796    #[test]
6797    fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
6798        let config = minimal_config();
6799        let model_rs = generate_model_rs(&config).unwrap();
6800
6801        let batch_start = model_rs
6802            .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
6803            .unwrap();
6804        let batch_body = &model_rs[batch_start..];
6805        let batch_end = batch_body
6806            .find("\n    pub fn reset")
6807            .unwrap_or(batch_body.len());
6808        let batch_code = &batch_body[..batch_end];
6809
6810        assert!(
6811            batch_code.contains("dispatch_matmul_batch"),
6812            "f32 forward_prefill_batch should use dispatch_matmul_batch"
6813        );
6814        // Should NOT use Q8 or Q4 batch dispatch
6815        assert!(
6816            !batch_code.contains("dispatch_matmul_q8_batch"),
6817            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
6818        );
6819        assert!(
6820            !batch_code.contains("dispatch_matmul_q4_batch"),
6821            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
6822        );
6823    }
6824
6825    #[test]
6826    fn forward_uses_cpu_embedding_lookup() {
6827        let config = minimal_config();
6828        let model_rs = generate_model_rs(&config).unwrap();
6829
6830        // Find just the forward() body (not forward_profile)
6831        let forward_start = model_rs
6832            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6833            .unwrap();
6834        let forward_body = &model_rs[forward_start..];
6835        let forward_end = forward_body
6836            .find("\n    pub fn forward_profile")
6837            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
6838            .unwrap_or(forward_body.len());
6839        let forward_code = &forward_body[..forward_end];
6840
6841        // forward() should use CPU memcpy for embedding lookup (unified memory)
6842        assert!(
6843            forward_code.contains("embed_buf.contents()"),
6844            "forward() should access embed_buf via CPU unified memory for embedding lookup"
6845        );
6846        assert!(
6847            forward_code.contains("copy_nonoverlapping"),
6848            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
6849        );
6850        // forward() should NOT use GPU dispatch for embedding
6851        assert!(
6852            !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
6853            "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
6854        );
6855    }
6856
6857    #[test]
6858    fn forward_profile_method_exists() {
6859        let config = minimal_config();
6860        let model_rs = generate_model_rs(&config).unwrap();
6861
6862        assert!(
6863            model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
6864            "MetalModel should have forward_profile() method"
6865        );
6866        // Profile method should print timing information
6867        assert!(
6868            model_rs.contains("[profile]"),
6869            "forward_profile() should print timing with [profile] prefix"
6870        );
6871        assert!(
6872            model_rs.contains("d_embed"),
6873            "forward_profile() should measure embedding time"
6874        );
6875        assert!(
6876            model_rs.contains("d_layers"),
6877            "forward_profile() should measure layer time"
6878        );
6879        assert!(
6880            model_rs.contains("d_logits"),
6881            "forward_profile() should measure logits time"
6882        );
6883    }
6884
6885    #[test]
6886    fn generated_cli_has_profile_flag() {
6887        let config = minimal_config();
6888        let main_rs = generate_main_rs("test-model", &config).unwrap();
6889
6890        assert!(
6891            main_rs.contains("--profile"),
6892            "CLI should support --profile flag"
6893        );
6894        assert!(
6895            main_rs.contains("forward_profile"),
6896            "CLI should call forward_profile when --profile is set"
6897        );
6898    }
6899
6900    #[test]
6901    fn generated_cli_has_thermal_yield() {
6902        let config = minimal_config();
6903        let main_rs = generate_main_rs("test-model", &config).unwrap();
6904
6905        assert!(
6906            main_rs.contains("yield_now()"),
6907            "CLI generation loop should include thread::yield_now() for thermal management"
6908        );
6909    }
6910
6911    // ── Real-world validation tests ──────────────────────────────────────
6912
6913    #[test]
6914    fn generated_forward_handles_single_token_prompt() {
6915        // With a single token (the first prompt token), forward() should work
6916        // at pos=0 where seq_len=1. The attention kernel must handle the case
6917        // where there is only one KV entry (no prefill context).
6918        let config = minimal_config();
6919        let model_rs = generate_model_rs(&config).unwrap();
6920
6921        // The forward function should accept any u32 token_id (no minimum pos guard)
6922        let forward_start = model_rs
6923            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6924            .expect("forward() must exist");
6925        let forward_body = &model_rs[forward_start..forward_start + 400];
6926
6927        // Should NOT require pos > 0 or seq_len > 1
6928        assert!(
6929            !forward_body.contains("assert!(self.pos > 0"),
6930            "forward() must accept pos=0 (first token with no prefill)"
6931        );
6932
6933        // The attention kernel should handle seq_len=1 via the pos field
6934        assert!(
6935            model_rs.contains("self.pos"),
6936            "forward() should use self.pos to track sequence position"
6937        );
6938    }
6939
6940    #[test]
6941    fn generated_reset_clears_kv_cache_position() {
6942        // After reset(), the model should be in a clean state. The pos field
6943        // must be 0 so new generation starts from scratch.
6944        let config = minimal_config();
6945        let model_rs = generate_model_rs(&config).unwrap();
6946
6947        let reset_start = model_rs
6948            .find("pub fn reset(&mut self)")
6949            .expect("reset() must exist");
6950        let reset_body = &model_rs[reset_start..reset_start + 200];
6951
6952        // Reset must zero the position counter
6953        assert!(
6954            reset_body.contains("self.pos = 0"),
6955            "reset() must set self.pos = 0"
6956        );
6957
6958        // Verify reset clears prev_cmd (double-buffering state)
6959        assert!(
6960            reset_body.contains("self.prev_cmd = None"),
6961            "reset() should clear prev_cmd for clean command buffer state"
6962        );
6963    }
6964
6965    #[test]
6966    fn generated_serve_handles_empty_messages_gracefully() {
6967        // The serve endpoint should not crash when receiving an empty messages array.
6968        // The format_chat_messages function should handle this gracefully.
6969        let config = minimal_config();
6970        let main_rs = generate_main_rs("test-model", &config).unwrap();
6971
6972        // The format_chat_messages function should exist and handle empty input
6973        let format_fn_start = main_rs
6974            .find("fn format_chat_messages")
6975            .expect("format_chat_messages must exist");
6976        let format_fn_body =
6977            &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
6978
6979        // It should iterate over messages (an empty slice produces an empty loop)
6980        assert!(
6981            format_fn_body.contains("for msg in messages"),
6982            "format_chat_messages should iterate over the messages slice"
6983        );
6984        // It should always append the assistant prompt suffix
6985        assert!(
6986            format_fn_body.contains("<|im_start|>assistant"),
6987            "format_chat_messages should always append assistant prompt header"
6988        );
6989
6990        // The serve function should call model.reset() before each request
6991        let serve_fn_start = main_rs
6992            .find("fn serve(")
6993            .expect("serve function must exist");
6994        let serve_fn_body = &main_rs[serve_fn_start..];
6995        assert!(
6996            serve_fn_body.contains("model.reset()"),
6997            "serve function should reset model between requests"
6998        );
6999    }
7000
7001    #[test]
7002    fn generated_model_forward_increments_pos() {
7003        // Each forward() call must increment self.pos so the next token
7004        // uses the correct RoPE position and KV cache offset.
7005        let config = minimal_config();
7006        let model_rs = generate_model_rs(&config).unwrap();
7007
7008        let forward_start = model_rs
7009            .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7010            .unwrap();
7011        let forward_body = &model_rs[forward_start..];
7012        let forward_end = forward_body
7013            .find("\n    pub fn forward_profile")
7014            .or_else(|| forward_body.find("\n    pub fn forward_prefill"))
7015            .or_else(|| forward_body.find("\n    fn dispatch_"))
7016            .unwrap_or(forward_body.len());
7017        let forward_code = &forward_body[..forward_end];
7018
7019        assert!(
7020            forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
7021            "forward() must increment self.pos after processing a token"
7022        );
7023    }
7024}