1use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16 #[error("graph has no model config")]
18 MissingConfig,
19
20 #[error("I/O error: {0}")]
22 Io(#[from] std::io::Error),
23
24 #[error("format error: {0}")]
26 Fmt(#[from] std::fmt::Error),
27}
28
29pub fn generate_metal_project(
37 graph: &Graph,
38 output_dir: &Path,
39 model_name: &str,
40) -> Result<(), MetalCodegenError> {
41 let config = graph
42 .config
43 .as_ref()
44 .ok_or(MetalCodegenError::MissingConfig)?;
45
46 let src_dir = output_dir.join("src");
47 let shader_dir = output_dir.join("shaders");
48 fs::create_dir_all(&src_dir)?;
49 fs::create_dir_all(&shader_dir)?;
50
51 fs::write(
52 output_dir.join("Cargo.toml"),
53 generate_cargo_toml(model_name),
54 )?;
55
56 fs::write(
57 shader_dir.join("kernels.metal"),
58 generate_metal_shaders(config),
59 )?;
60
61 let model_rs = generate_model_rs(config)?;
62 fs::write(src_dir.join("model.rs"), model_rs)?;
63
64 let main_rs = generate_main_rs(model_name, config)?;
65 fs::write(src_dir.join("main.rs"), main_rs)?;
66
67 Ok(())
68}
69
/// Turns an arbitrary model name into a valid Cargo package / binary name.
///
/// The name is lowercased, and every character that is not an ASCII letter,
/// ASCII digit, or `-` becomes `-`. Leading and trailing dashes are trimmed.
/// Cargo rejects empty package names and names that start with a digit, so
/// those cases are repaired:
///
/// - `"Llama-2 7B"` -> `"llama-2-7b"`
/// - `"7B-Chat"`    -> `"forgellm-7b-chat"` (digit-leading names get a prefix)
/// - `"!!!"`        -> `"forgellm-model"`   (nothing usable left -> fallback)
///
/// Only ASCII is kept. Unicode "alphanumerics" such as `²` or `½` are not
/// valid identifier characters for Cargo and would otherwise slip through.
fn sanitize_name(name: &str) -> String {
    // Used when nothing survives sanitization (e.g. empty or punctuation-only input).
    const FALLBACK: &str = "forgellm-model";
    // Prepended when the sanitized name would start with a digit.
    const DIGIT_PREFIX: &str = "forgellm-";

    let dashed: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = dashed.trim_matches('-');

    match trimmed.chars().next() {
        None => FALLBACK.to_string(),
        Some(first) if first.is_ascii_digit() => format!("{DIGIT_PREFIX}{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82 let sanitized = sanitize_name(model_name);
83 format!(
84 r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108 )
109}
110
111fn generate_metal_shaders(config: &ModelConfig) -> String {
116 let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121 let attn_scores_size = config.max_seq_len.min(4096);
126 r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154 device const float* matrix [[buffer(0)]],
155 device const float* vector [[buffer(1)]],
156 device float* output [[buffer(2)]],
157 constant uint& rows [[buffer(3)]],
158 constant uint& cols [[buffer(4)]],
159 uint tgid [[threadgroup_position_in_grid]],
160 uint tid [[thread_index_in_threadgroup]],
161 uint simd_lane [[thread_index_in_simdgroup]],
162 uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164 // Cooperatively load vector into threadgroup shared memory
165 threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166 for (uint i = tid; i < cols; i += 256) {
167 vec_tile[i] = vector[i];
168 }
169 threadgroup_barrier(mem_flags::mem_threadgroup);
170
171 // Each simdgroup handles 8 consecutive rows
172 uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173 if (row_base >= rows) return;
174
175 uint base0 = row_base * cols;
176 uint base1 = (row_base + 1) * cols;
177 uint base2 = (row_base + 2) * cols;
178 uint base3 = (row_base + 3) * cols;
179 uint base4 = (row_base + 4) * cols;
180 uint base5 = (row_base + 5) * cols;
181 uint base6 = (row_base + 6) * cols;
182 uint base7 = (row_base + 7) * cols;
183
184 // float4 vectorized accumulation across 8 rows
185 uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
186 float4 sum4_0 = float4(0.0f);
187 float4 sum4_1 = float4(0.0f);
188 float4 sum4_2 = float4(0.0f);
189 float4 sum4_3 = float4(0.0f);
190 float4 sum4_4 = float4(0.0f);
191 float4 sum4_5 = float4(0.0f);
192 float4 sum4_6 = float4(0.0f);
193 float4 sum4_7 = float4(0.0f);
194
195 for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196 float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197 sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198 if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199 if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200 if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201 if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202 if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203 if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204 if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205 }
206
207 float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208 float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209 float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210 float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211 float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212 float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213 float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214 float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216 // Handle remaining elements (cols not divisible by 128)
217 for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218 float vv = vec_tile[j];
219 sum0 += matrix[base0 + j] * vv;
220 if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221 if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222 if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223 if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224 if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225 if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226 if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227 }
228
229 // Simdgroup hardware warp-level reduction
230 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232 sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233 sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235 // Only first lane writes the results
236 if (simd_lane == 0) {
237 if (row_base < rows) output[row_base] = sum0;
238 if (row_base + 1 < rows) output[row_base + 1] = sum1;
239 if (row_base + 2 < rows) output[row_base + 2] = sum2;
240 if (row_base + 3 < rows) output[row_base + 3] = sum3;
241 if (row_base + 4 < rows) output[row_base + 4] = sum4;
242 if (row_base + 5 < rows) output[row_base + 5] = sum5;
243 if (row_base + 6 < rows) output[row_base + 6] = sum6;
244 if (row_base + 7 < rows) output[row_base + 7] = sum7;
245 }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253 device const float* input [[buffer(0)]],
254 device const float* weight [[buffer(1)]],
255 device float* output [[buffer(2)]],
256 constant uint& n [[buffer(3)]],
257 constant float& eps [[buffer(4)]],
258 uint tid [[thread_index_in_threadgroup]])
259{
260 // Each thread accumulates partial sum-of-squares
261 float sum_sq = 0.0f;
262 for (uint i = tid; i < n; i += 256) {
263 float v = input[i];
264 sum_sq += v * v;
265 }
266
267 // Simdgroup-level reduction (hardware warp sum)
268 sum_sq = simd_sum(sum_sq);
269
270 // Cross-simdgroup reduction via shared memory
271 threadgroup float shared[8];
272 uint simd_id = tid / 32;
273 uint simd_lane = tid % 32;
274 if (simd_lane == 0) {
275 shared[simd_id] = sum_sq;
276 }
277 threadgroup_barrier(mem_flags::mem_threadgroup);
278
279 // First thread computes final inverse RMS
280 if (tid == 0) {
281 float total = 0.0f;
282 for (uint i = 0; i < 8; i++) {
283 total += shared[i];
284 }
285 shared[0] = fast::rsqrt(total / float(n) + eps);
286 }
287 threadgroup_barrier(mem_flags::mem_threadgroup);
288
289 float inv_rms = shared[0];
290
291 // Normalize
292 for (uint i = tid; i < n; i += 256) {
293 output[i] = input[i] * inv_rms * weight[i];
294 }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301 device float* data [[buffer(0)]],
302 constant uint& num_heads [[buffer(1)]],
303 constant uint& head_dim [[buffer(2)]],
304 constant uint& pos [[buffer(3)]],
305 constant float& theta [[buffer(4)]],
306 uint id [[thread_position_in_grid]])
307{
308 uint half_dim = head_dim / 2;
309 uint total_pairs = num_heads * half_dim;
310 if (id >= total_pairs) return;
311
312 uint h = id / half_dim;
313 uint i = id % half_dim;
314 uint off = h * head_dim;
315
316 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317 float angle = float(pos) * freq;
318 float c = cos(angle);
319 float s = sin(angle);
320
321 float x0 = data[off + 2 * i];
322 float x1 = data[off + 2 * i + 1];
323 data[off + 2 * i] = x0 * c - x1 * s;
324 data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331 device float* data [[buffer(0)]],
332 constant uint& n [[buffer(1)]],
333 uint tid [[thread_index_in_threadgroup]],
334 uint tg_size [[threads_per_threadgroup]])
335{
336 threadgroup float shared_val[256];
337
338 // Pass 1: find max
339 float local_max = -INFINITY;
340 for (uint i = tid; i < n; i += tg_size) {
341 local_max = max(local_max, data[i]);
342 }
343 shared_val[tid] = local_max;
344 threadgroup_barrier(mem_flags::mem_threadgroup);
345
346 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347 if (tid < stride) {
348 shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349 }
350 threadgroup_barrier(mem_flags::mem_threadgroup);
351 }
352 float max_val = shared_val[0];
353 threadgroup_barrier(mem_flags::mem_threadgroup);
354
355 // Pass 2: exp and sum
356 float local_sum = 0.0f;
357 for (uint i = tid; i < n; i += tg_size) {
358 float e = fast::exp(data[i] - max_val);
359 data[i] = e;
360 local_sum += e;
361 }
362 shared_val[tid] = local_sum;
363 threadgroup_barrier(mem_flags::mem_threadgroup);
364
365 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366 if (tid < stride) {
367 shared_val[tid] += shared_val[tid + stride];
368 }
369 threadgroup_barrier(mem_flags::mem_threadgroup);
370 }
371 float inv_sum = 1.0f / shared_val[0];
372 threadgroup_barrier(mem_flags::mem_threadgroup);
373
374 // Pass 3: normalize
375 for (uint i = tid; i < n; i += tg_size) {
376 data[i] *= inv_sum;
377 }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384 device const float* gate [[buffer(0)]],
385 device const float* up [[buffer(1)]],
386 device float* output [[buffer(2)]],
387 constant uint& n [[buffer(3)]],
388 uint id [[thread_position_in_grid]])
389{
390 if (id >= n) return;
391 float g = gate[id];
392 output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397// gate = gate_up[0..n], up = gate_up[n..2*n]
398// output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400 device const float* gate_up [[buffer(0)]],
401 device float* output [[buffer(1)]],
402 constant uint& n [[buffer(2)]],
403 uint id [[thread_position_in_grid]])
404{
405 if (id >= n) return;
406 float g = gate_up[id];
407 float u = gate_up[n + id];
408 output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414 device const float* a [[buffer(0)]],
415 device const float* b [[buffer(1)]],
416 device float* output [[buffer(2)]],
417 constant uint& n [[buffer(3)]],
418 uint id [[thread_position_in_grid]])
419{
420 if (id >= n) return;
421 output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428 device const float* src [[buffer(0)]],
429 device float* dst [[buffer(1)]],
430 constant uint& count [[buffer(2)]],
431 uint id [[thread_position_in_grid]])
432{
433 if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440 device const float* src [[buffer(0)]],
441 device float* dst [[buffer(1)]],
442 constant uint& src_offset [[buffer(2)]], // in floats
443 constant uint& count [[buffer(3)]],
444 uint id [[thread_position_in_grid]])
445{
446 if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
450// Copy f32 elements from src into a half-typed dst, converting to half on
451// write. Used by the single-token decode path to append a new K/V vector
452// to the f16 KV cache. Byte offsets into src/dst are supplied via the
453// Metal buffer binding offsets — no in-kernel offsets needed.
454kernel void copy_f32_to_f16_offset(
455 device const float* src [[buffer(0)]],
456 device half* dst [[buffer(1)]],
457 constant uint& count [[buffer(2)]],
458 uint id [[thread_position_in_grid]])
459{
460 if (id < count) dst[id] = half(src[id]);
461}
462
463// ── add_inplace ─────────────────────────────────────────────────────────
464// In-place residual connection: a[i] += b[i]
465// Avoids a separate blit copy for residual add, reducing encoder overhead.
466kernel void add_inplace(
467 device float* a [[buffer(0)]],
468 device const float* b [[buffer(1)]],
469 constant uint& n [[buffer(2)]],
470 uint id [[thread_position_in_grid]])
471{
472 if (id >= n) return;
473 a[id] += b[id];
474}
475
476// ── matmul_vec_q8 ─────────────────────────────────────────────────────
477// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
478// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
479// Operates directly on quantized weights to halve memory bandwidth vs f32,
480// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
481//
482// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
483// because int8->float conversion doubles register demand. Fully unrolled
484// inner loop with float4 vector loads from shared memory eliminates loop
485// overhead and enables better instruction scheduling.
486// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
487constant constexpr uint Q8_ROWS_PER_SG = 4;
488constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
489
490// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
491// doubles ALU work, so the same register budget applies).
492constant constexpr uint Q4_ROWS_PER_SG = 4;
493constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
494
495kernel void matmul_vec_q8(
496 device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
497 device const float* vector [[buffer(1)]], // f32 input
498 device float* output [[buffer(2)]],
499 constant uint& rows [[buffer(3)]],
500 constant uint& cols [[buffer(4)]], // number of elements per row
501 uint tgid [[threadgroup_position_in_grid]],
502 uint tid [[thread_index_in_threadgroup]],
503 uint simd_lane [[thread_index_in_simdgroup]],
504 uint simd_id [[simdgroup_index_in_threadgroup]])
505{
506 // Load vector into shared memory
507 threadgroup float vec_tile[VEC_TILE_SIZE];
508 for (uint i = tid; i < cols; i += 256) {
509 vec_tile[i] = vector[i];
510 }
511 threadgroup_barrier(mem_flags::mem_threadgroup);
512
513 // Each simdgroup handles 4 consecutive rows (lower register pressure)
514 uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
515 if (row_base >= rows) return;
516
517 // Q8_0: each block is 34 bytes for 32 elements
518 uint blocks_per_row = cols / 32;
519 uint row_bytes = blocks_per_row * 34;
520
521 // Pointers to each row's Q8_0 data
522 device const uchar* r0 = matrix + row_base * row_bytes;
523 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
524 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
525 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
526
527 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
528
529 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
530 uint bb = blk * 34;
531 uint vb = blk * 32;
532
533 // Prefetch all 4 scales
534 float sc0 = float(*(device const half*)(r0 + bb));
535 float sc1 = float(*(device const half*)(r1 + bb));
536 float sc2 = float(*(device const half*)(r2 + bb));
537 float sc3 = float(*(device const half*)(r3 + bb));
538
539 // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
540 // Q8_0 block layout where the int8 data starts at offset +2 from a
541 // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
542 // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
543 // reduction in memory transactions. Metal's char16/packed_char16 are
544 // reserved types and packed_*int4 require >=4-byte alignment which
545 // this layout does not provide, so packed_short4 is the widest valid
546 // vectorized load.
547 device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
548 device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
549 device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
550 device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
551
552 // Load all 8 float4 vector values for this 32-element block from shared memory
553 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
554 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
555 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
556 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
557 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
558 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
559 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
560 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
561
562 // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
563 // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
564 #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
565 short4 _s = short4(SHORT4); \
566 char2 _a = as_type<char2>(_s.x); \
567 char2 _b = as_type<char2>(_s.y); \
568 char2 _c = as_type<char2>(_s.z); \
569 char2 _d = as_type<char2>(_s.w); \
570 (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
571 (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
572 }
573
574 float4 f0, f1;
575 float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
576
577 // Row 0: 4 short4 loads cover 32 int8 weights
578 Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
579 Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
580 Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
581 Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
582
583 // Row 1
584 Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
585 Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
586 Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
587 Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
588
589 // Row 2
590 Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
591 Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
592 Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
593 Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
594
595 // Row 3
596 Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
597 Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
598 Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
599 Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
600
601 #undef Q8_UNPACK8
602
603 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
604 }
605
606 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
607 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
608
609 if (simd_lane == 0) {
610 if (row_base < rows) output[row_base] = sum0;
611 if (row_base + 1 < rows) output[row_base + 1] = sum1;
612 if (row_base + 2 < rows) output[row_base + 2] = sum2;
613 if (row_base + 3 < rows) output[row_base + 3] = sum3;
614 }
615}
616
617// ── matmul_vec_q4 ─────────────────────────────────────────────────────
618// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
619// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
620// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
621// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
622//
623// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
624// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
625kernel void matmul_vec_q4(
626 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
627 device const float* vector [[buffer(1)]], // f32 input
628 device float* output [[buffer(2)]],
629 constant uint& rows [[buffer(3)]],
630 constant uint& cols [[buffer(4)]], // number of elements per row
631 uint tgid [[threadgroup_position_in_grid]],
632 uint tid [[thread_index_in_threadgroup]],
633 uint simd_lane [[thread_index_in_simdgroup]],
634 uint simd_id [[simdgroup_index_in_threadgroup]])
635{
636 // Load vector into shared memory
637 threadgroup float vec_tile[VEC_TILE_SIZE];
638 for (uint i = tid; i < cols; i += 256) {
639 vec_tile[i] = vector[i];
640 }
641 threadgroup_barrier(mem_flags::mem_threadgroup);
642
643 // Each simdgroup handles 4 consecutive rows
644 uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
645 if (row_base >= rows) return;
646
647 // Q4_0: each block is 18 bytes for 32 elements
648 uint blocks_per_row = cols / 32;
649 uint row_bytes = blocks_per_row * 18;
650
651 // Pointers to each row's Q4_0 data
652 device const uchar* r0 = matrix + row_base * row_bytes;
653 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
654 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
655 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
656
657 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
658
659 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
660 uint bb = blk * 18;
661 uint vb = blk * 32;
662
663 // Prefetch all 4 scales
664 float sc0 = float(*(device const half*)(r0 + bb));
665 float sc1 = float(*(device const half*)(r1 + bb));
666 float sc2 = float(*(device const half*)(r2 + bb));
667 float sc3 = float(*(device const half*)(r3 + bb));
668
669 // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
670 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
671 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
672 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
673 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
674
675 // Load 8 float4 vector values for 32 elements from shared memory
676 // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
677 float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
678 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
679 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
680 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
681 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
682 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
683 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
684 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
685
686 // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
687 // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
688 float bd0=0, bd1=0, bd2=0, bd3=0;
689 uchar4 b;
690
691 // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
692 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
693 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
694 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
695 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
696 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
697 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
698 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
699 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
700 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
701 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
702 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
703 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
704 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
705 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
706 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
707 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
708
709 // Row 1
710 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
711 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
712 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
713 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
714 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
715 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
716 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
717 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
718 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
719 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
720 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
721 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
722 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
723 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
724 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
725 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
726
727 // Row 2
728 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
729 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
730 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
731 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
732 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
733 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
734 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
735 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
736 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
737 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
738 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
739 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
740 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
741 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
742 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
743 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
744
745 // Row 3
746 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
747 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
748 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
749 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
750 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
751 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
752 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
753 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
754 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
755 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
756 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
757 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
758 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
759 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
760 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
761 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
762
763 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
764 }
765
766 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
767 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
768
769 if (simd_lane == 0) {
770 if (row_base < rows) output[row_base] = sum0;
771 if (row_base + 1 < rows) output[row_base + 1] = sum1;
772 if (row_base + 2 < rows) output[row_base + 2] = sum2;
773 if (row_base + 3 < rows) output[row_base + 3] = sum3;
774 }
775}
776
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then weighted sum of V.
// Each threadgroup handles one head with 256 threads (8 simdgroups); the
// 256/8 loop strides and the 8-slot reduction arrays below depend on it.
// Grouped-query attention: query head `head` reads KV head
// head / (num_heads / num_kv_heads).
//
// Buffers:
//   q:       [num_heads * head_dim] current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
//   seq_len: number of cached positions to attend over. Must not exceed
//            ATTN_SCORES_SIZE (size of the threadgroup score buffer); this
//            kernel does not check it.
kernel void attention(
    device const float* q            [[buffer(0)]],
    device const half*  k_cache      [[buffer(1)]],
    device const half*  v_cache      [[buffer(2)]],
    device float*       output       [[buffer(3)]],
    constant uint&      seq_len      [[buffer(4)]],
    constant uint&      num_heads    [[buffer(5)]],
    constant uint&      num_kv_heads [[buffer(6)]],
    constant uint&      head_dim     [[buffer(7)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Step 1: Compute attention scores Q·K^T with simdgroup reduction.
    // One score per cached position is kept in threadgroup memory; the
    // buffer holds ATTN_SCORES_SIZE floats (expected to match the host's
    // MAX_SEQ_LEN cap), so seq_len beyond it would overflow scores[].
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Q·K^T with half4/float4 vectorized loads.
    // Each simdgroup handles one s; 32 lanes cover head_dim in chunks of 4.
    // For head_dim=128, every lane does exactly one half4 load (no loop).
    // For head_dim=64, 16 lanes active; for head_dim=96, 24 lanes.
    //
    // The vector loads/stores start at offsets that are multiples of
    // head_dim, so they are only properly aligned when head_dim % 4 == 0.
    // For any other head_dim the vector path is disabled (head_dim4 = 0)
    // and the scalar fallback loops cover every dimension.
    uint head_dim4 = (head_dim % 4 == 0) ? head_dim / 4 : 0;
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d4 = simd_lane; d4 < head_dim4; d4 += 32) {
            uint d = d4 * 4;
            float4 q4 = *(device const float4*)(q + q_off + d);
            float4 k4 = float4(*(device const half4*)(k_cache + k_off + d));
            dot += q4.x * k4.x + q4.y * k4.y + q4.z * k4.z + q4.w * k4.w;
        }
        // Scalar fallback for the dimensions not covered above (all of them
        // when head_dim is not a multiple of 4; none for head_dim ∈ {64, 96, 128}).
        for (uint d = head_dim4 * 4 + simd_lane; d < head_dim; d += 32) {
            dot += q[q_off + d] * float(k_cache[k_off + d]);
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            // Scaled dot-product: divide by sqrt(head_dim).
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax over scores (cooperative)
    // Find max: per-thread max, then simd_max, then thread 0 folds the
    // 8 simdgroup partials.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exp and sum (scores are overwritten in place with exp(score - max)).
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V: output = scores · V, half4 vectorized.
    // Each thread handles 4 d-dims via half4 loads (when head_dim4 > 0),
    // with a scalar fallback for the remaining dimensions.
    uint seq_len4 = seq_len & ~3u; // largest multiple of 4 <= seq_len
    uint v_stride = num_kv_heads * head_dim;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
                 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
                 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
                 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
        }
        *(device float4*)(output + q_off + d) = acc;
    }
    // Scalar fallback for dimensions not covered by the vector loop above.
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * float(v_cache[s * v_stride + v_base]);
        }
        output[q_off + d] = acc;
    }
}
915
916// ── Batched prefill kernels ────────────────────────────────────────────
917// These kernels process M input vectors against the same weight matrix
918// in a single dispatch, converting mat-vec into mat-mat for better GPU
919// utilization during prompt prefill.
920
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization, one threadgroup per token row.
// Row t (the n floats starting at input[t * n]) is scaled by
// rsqrt(mean(x^2) + eps) and multiplied elementwise by `weight`.
// Grid: M threadgroups. The 256-wide strided loops and the 8-slot partial
// array assume 256 threads per threadgroup (8 simdgroups of 32 lanes).
kernel void rms_norm_batch(
    device const float* input      [[buffer(0)]],  // [M, n]
    device const float* weight     [[buffer(1)]],  // [n]
    device float*       output     [[buffer(2)]],  // [M, n]
    constant uint&      n          [[buffer(3)]],
    constant float&     eps        [[buffer(4)]],
    constant uint&      num_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid  [[thread_index_in_threadgroup]])
{
    if (tgid >= num_tokens) return;

    const uint row_off = tgid * n;

    // Each thread accumulates squares over a 256-strided slice of the row.
    float partial = 0.0f;
    for (uint col = tid; col < n; col += 256) {
        const float x = input[row_off + col];
        partial += x * x;
    }

    // Reduce within the simdgroup, then across simdgroups via threadgroup memory.
    partial = simd_sum(partial);

    threadgroup float simd_partials[8];
    if (tid % 32 == 0) {
        simd_partials[tid / 32] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the partials and publishes the reciprocal RMS in slot 0.
    if (tid == 0) {
        float row_sum_sq = 0.0f;
        for (uint g = 0; g < 8; g++) {
            row_sum_sq += simd_partials[g];
        }
        simd_partials[0] = fast::rsqrt(row_sum_sq / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float scale = simd_partials[0];
    for (uint col = tid; col < n; col += 256) {
        output[row_off + col] = input[row_off + col] * scale * weight[col];
    }
}
970
// ── rope_batch ─────────────────────────────────────────────────────────
// Rotary Position Embedding applied in place to a batch of vectors, each
// token with its own position. Adjacent elements (2i, 2i+1) within a head
// form a rotation pair with inverse frequency theta^(-2i/head_dim).
// data layout: [M, num_heads * head_dim], positions: [M]
// One thread per (token, head, pair) combination.
kernel void rope_batch(
    device float*      data       [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint&     num_heads  [[buffer(1)]],
    constant uint&     head_dim   [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float&    theta      [[buffer(4)]],
    constant uint&     num_tokens [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head  = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;
    // Bounds check happens before any division by pairs_per_token/head.
    if (id >= num_tokens * pairs_per_token) return;

    // Split the flat thread index into (token, head, pair).
    const uint token        = id / pairs_per_token;
    const uint within_token = id % pairs_per_token;
    const uint head         = within_token / pairs_per_head;
    const uint pair         = within_token % pairs_per_head;

    // Index of the even element of this pair.
    const uint idx = token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle = float(positions[token]) * inv_freq;
    const float cos_a = cos(angle);
    const float sin_a = sin(angle);

    const float even = data[idx];
    const float odd  = data[idx + 1];
    data[idx]     = even * cos_a - odd * sin_a;
    data[idx + 1] = even * sin_a + odd * cos_a;
}
1005
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-gate multiply over a batch. Each token row of gate_up holds
// the gate half followed by the up half: [gate(n) | up(n)].
// output[t, i] = silu(gate_up[t, i]) * gate_up[t, n + i], one thread per
// output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up    [[buffer(0)]],  // [M, 2*n]
    device float*       output     [[buffer(1)]],  // [M, n]
    constant uint&      n          [[buffer(2)]],
    constant uint&      num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint row = id / n;
    const uint col = id % n;

    const uint gate_idx = row * 2 * n + col;
    const float gate = gate_up[gate_idx];
    const float up   = gate_up[gate_idx + n];

    // silu(x) = x * sigmoid(x) = x / (1 + e^-x)
    const float silu = gate / (1.0f + fast::exp(-gate));
    // id == row * n + col, the element's slot in the [M, n] output.
    output[id] = silu * up;
}
1025
// ── add_inplace_batch ──────────────────────────────────────────────────
// Residual add over a flattened batch: a[i] += b[i] for every i < total
// (total = M * n). One thread per element; extra threads do nothing.
kernel void add_inplace_batch(
    device float*       a     [[buffer(0)]],  // [M * n], updated in place
    device const float* b     [[buffer(1)]],  // [M * n]
    constant uint&      total [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1037
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gathers M embedding rows into a contiguous [M, dim] batch buffer:
// output row m is embedding-table row tokens[m]. One thread per output float.
kernel void copy_embedding_batch(
    device const float* embed      [[buffer(0)]],  // [vocab_size, dim]
    device float*       output     [[buffer(1)]],  // [M, dim]
    device const uint*  tokens     [[buffer(2)]],  // [M]
    constant uint&      dim        [[buffer(3)]],
    constant uint&      num_tokens [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * dim) {
        const uint slot  = id / dim;
        const uint col   = id % dim;
        const uint vocab = tokens[slot];
        output[id] = embed[vocab * dim + col];
    }
}
1055
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair. Within it, each
// simdgroup computes 8 consecutive rows (the body is unrolled for 8, so
// ROWS_PER_SIMDGROUP is expected to be 8) and reduces across its 32 lanes
// with simd_sum.
// Assumes 256 threads per threadgroup (vec_tile load stride) and
// cols <= VEC_TILE_SIZE (vec_tile is not bounds-checked) — TODO confirm
// both against the dispatch code.
kernel void matmul_vec_batch(
    device const float* matrix     [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so simdgroups past the last row may exit here.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // The vector loop covers whole 128-column strips (32 lanes x float4).
    // Row r's float4 loads start at r * cols, which is only 16-byte aligned
    // when cols % 4 == 0; otherwise skip the vector path so the scalar tail
    // loop below handles every column.
    uint cols_vec4 = (cols % 4 == 0) ? (cols & ~127u) : 0u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    // Rows past `rows` are skipped so no device read leaves the matrix.
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail: columns not covered by the vector loop.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows)     output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1157
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups; each threadgroup handles
// one (token, row_group) pair. Each simdgroup computes 4 consecutive rows
// (the body is unrolled for 4, so Q8_ROWS_PER_SG is expected to be 4) and
// each lane walks Q8_0 blocks with stride 32 before a simd_sum reduction.
// Q8_0 row layout: cols/32 blocks of 34 bytes = f16 scale + 32 int8 quants.
// Assumes 256 threads per threadgroup (vec_tile load stride), cols % 32 == 0
// and cols <= VEC_TILE_SIZE — TODO confirm against the dispatch code.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so simdgroups past the last row may exit here.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4, the trailing rows of the last
    // simdgroup lie past the end of the matrix. Clamp them to the last valid
    // row so every device read below stays in bounds; the guarded stores at
    // the end discard their results.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;
        uint vb = blk * 32;

        // Per-block f16 scale at the start of each 34-byte block.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 signed bytes) into two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows)     output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1273
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers Q8_ROWS_PER_TG rows and TOKENS_PER_TG_Q8
// consecutive tokens, so the Q8_0 weight blocks are fetched once from device
// memory and reused for every token in the tile (1/TOKENS_PER_TG_Q8 the
// weight bandwidth of the per-token dispatch).
//
// Grid: (ceil(rows/Q8_ROWS_PER_TG), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each simdgroup handles 4 rows (the body is unrolled for 4, so
// Q8_ROWS_PER_SG is expected to be 4) and reduces over blocks with simd_sum.
// Token vectors are read directly from device memory inside the block loop
// (not cached in shared memory), so intermediate_size up to 8192 fits
// without spilling threadgroup memory. The token count per tile is
// hard-coded to 4 in the body, matching TOKENS_PER_TG_Q8.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // How many tokens in this tile are valid?
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4, the trailing rows of the last
    // simdgroup lie past the end of the matrix. Clamp them to the last valid
    // row so every device read below stays in bounds; the guarded stores at
    // the end discard their results.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators: 4 tokens × 4 rows per simdgroup (s<token><row>).
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;
        uint vb = blk * 32;

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // Split one packed_short4 (8 signed bytes) into two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        // Unpack all 4 rows × 8 float4 weights (unscaled int8 values; the
        // per-block scale is applied after the dot products). These are
        // dotted against every token's vector slice for this block.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        #undef Q8_UNPACK8

        // ── For each token, read vector and accumulate against shared weights ──
        // Token 0 (always valid: tok_count >= 1).
        {
            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
            float4 v0 = *(device const float4*)(a0);
            float4 v1 = *(device const float4*)(a0 + 4);
            float4 v2 = *(device const float4*)(a0 + 8);
            float4 v3 = *(device const float4*)(a0 + 12);
            float4 v4 = *(device const float4*)(a0 + 16);
            float4 v5 = *(device const float4*)(a0 + 20);
            float4 v6 = *(device const float4*)(a0 + 24);
            float4 v7 = *(device const float4*)(a0 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
        }
        // Token 1
        if (tok_count > 1) {
            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
            float4 v0 = *(device const float4*)(a1);
            float4 v1 = *(device const float4*)(a1 + 4);
            float4 v2 = *(device const float4*)(a1 + 8);
            float4 v3 = *(device const float4*)(a1 + 12);
            float4 v4 = *(device const float4*)(a1 + 16);
            float4 v5 = *(device const float4*)(a1 + 20);
            float4 v6 = *(device const float4*)(a1 + 24);
            float4 v7 = *(device const float4*)(a1 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
        }
        // Token 2
        if (tok_count > 2) {
            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
            float4 v0 = *(device const float4*)(a2);
            float4 v1 = *(device const float4*)(a2 + 4);
            float4 v2 = *(device const float4*)(a2 + 8);
            float4 v3 = *(device const float4*)(a2 + 12);
            float4 v4 = *(device const float4*)(a2 + 16);
            float4 v5 = *(device const float4*)(a2 + 20);
            float4 v6 = *(device const float4*)(a2 + 24);
            float4 v7 = *(device const float4*)(a2 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
        }
        // Token 3
        if (tok_count > 3) {
            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
            float4 v0 = *(device const float4*)(a3);
            float4 v1 = *(device const float4*)(a3 + 4);
            float4 v2 = *(device const float4*)(a3 + 8);
            float4 v3 = *(device const float4*)(a3 + 12);
            float4 v4 = *(device const float4*)(a3 + 16);
            float4 v5 = *(device const float4*)(a3 + 20);
            float4 v6 = *(device const float4*)(a3 + 24);
            float4 v7 = *(device const float4*)(a3 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
        }
    }

    // simdgroup reduction
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    // Only rows < rows and tokens < tok_count are stored.
    if (simd_lane == 0) {
        device float* o0 = outputs + (tok_base + 0) * rows;
        if (row_base < rows)     o0[row_base]     = s00;
        if (row_base + 1 < rows) o0[row_base + 1] = s01;
        if (row_base + 2 < rows) o0[row_base + 2] = s02;
        if (row_base + 3 < rows) o0[row_base + 3] = s03;

        if (tok_count > 1) {
            device float* o1 = outputs + (tok_base + 1) * rows;
            if (row_base < rows)     o1[row_base]     = s10;
            if (row_base + 1 < rows) o1[row_base + 1] = s11;
            if (row_base + 2 < rows) o1[row_base + 2] = s12;
            if (row_base + 3 < rows) o1[row_base + 3] = s13;
        }
        if (tok_count > 2) {
            device float* o2 = outputs + (tok_base + 2) * rows;
            if (row_base < rows)     o2[row_base]     = s20;
            if (row_base + 1 < rows) o2[row_base + 1] = s21;
            if (row_base + 2 < rows) o2[row_base + 2] = s22;
            if (row_base + 3 < rows) o2[row_base + 3] = s23;
        }
        if (tok_count > 3) {
            device float* o3 = outputs + (tok_base + 3) * rows;
            if (row_base < rows)     o3[row_base]     = s30;
            if (row_base + 1 < rows) o3[row_base + 1] = s31;
            if (row_base + 2 < rows) o3[row_base + 2] = s32;
            if (row_base + 3 < rows) o3[row_base + 3] = s33;
        }
    }
}
1499
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate). Computes
//   outputs[tok * rows + row] = sum_k dequant(matrix)[row, k] * inputs[tok * cols + k]
// for tok < num_tokens, row < rows. Intended as the prompt-prefill path for
// M >= MMA_TOK_TILE inputs, where it is expected to beat the scalar
// dot-product GEMM (NOTE(review): throughput claim comes from the original
// author and is not measured here).
//
// Q8_0 layout as read below: each weight row is cols/32 blocks of 34 bytes,
// a half scale (2 bytes) followed by 32 int8 quants; w = int8 * scale.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// tgid.x selects the 16-row tile and tgid.y the 16-token tile; threadgroups
// past either end exit immediately. 4 simdgroups per TG, each computing a
// single 8×8 output sub-tile via one simdgroup_matrix<float, 8, 8>
// accumulator. Weight bytes are cooperatively dequantized into threadgroup
// memory once per block and reused by all simdgroups in the tile. The
// cooperative loads give each thread 4 consecutive elements of a
// 512-element tile, so the threadgroup must have exactly 128 threads.
//
// Assumptions (verified in the dispatch helper, falls back otherwise):
//   * cols % 32 == 0 (one Q8_0 block per K chunk)
//   * rows % 16 == 0 (tile-aligned; true for all supported architectures)
//   * num_tokens may be any value; a partial token tile at the end of the
//     batch is handled via a scratch copy path in the store.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix [[buffer(0)]], // Q8_0 bytes [rows, cols/32 * 34]
    device const float* inputs [[buffer(1)]], // [M, cols]
    device float* outputs [[buffer(2)]],      // [M, rows]
    constant uint& num_tokens [[buffer(3)]],  // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup returns together and
    // the barriers below are never reached by only part of it.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    // Both are row-major with a stride of 32 elements (one K chunk).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8; // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8; // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34; // 2-byte half scale + 32 int8 per block

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Token slots past num_tokens are zero-filled so their MMA rows are
        // zero and never read out (the store skips them).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Both tiles must be fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]   (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]   (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B,
                           w_tile + sg_row_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Every simdgroup must finish reading the tiles before the next
        // iteration overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows. The row
    // dimension is in bounds by the rows % 16 == 0 assumption; only the
    // token dimension can run past the end of the batch.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile at the end of the batch: stage the 8×8 result
        // in this simdgroup's own 64-float scratch slice (offset simd_id * 64)
        // and have lane 0 copy only the valid token rows.
        threadgroup float scratch[4 * 64];
        simdgroup_store(C, scratch + simd_id * 64, 8);
        // Order this simdgroup's scratch writes before lane 0 reads them.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok; // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + simd_id * 64 + m * 8;
                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
            }
        }
    }
    // Sub-tiles starting at or beyond num_tokens store nothing.
}
1628
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill. Same math
// and Q8_0 byte layout (34-byte blocks: half scale + 32 int8 quants):
//   outputs[tok * rows + row] = sum_k dequant(matrix)[row, k] * inputs[tok * cols + k]
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads; the cooperative loads give each thread 4
// consecutive elements of a 1024-element tile, so exactly 256 are required)
// cover the 4×4 grid of 16 8×8 output tiles. Each simdgroup owns *two*
// adjacent 8×8 accumulators along the row axis:
//
//   simd_id = 2*sg_tok_idx + sg_row_half   (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//   output sub-tiles (tok, row):
//     (sg_tok_idx*8, sg_row_half*16 + 0) -> C_a
//     (sg_tok_idx*8, sg_row_half*16 + 8) -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration: 3 fragment loads per 2 MMAs, versus 2 loads per 1 MMA in the
// 16×16 kernel. A 32×32 tile also covers 4× the output area of the 16×16
// tile, so a quarter as many threadgroups are dispatched.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
// A partial final token tile (num_tokens % 32 != 0) is handled in the store.
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup, so no thread is left waiting at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx = simd_id / 2;  // 0..3
    uint sg_row_half = simd_id % 2; // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load; token slots past num_tokens are zero-filled.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Tiles fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // All reads of the tiles done before the next block overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators. rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Stage both accumulators in this simdgroup's 128-float scratch slice
        // (C_a at +0, C_b at +64); lane 0 copies only the valid token rows.
        threadgroup float scratch[8 * 2 * 64]; // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        // Order the scratch writes before lane 0 reads them.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok; // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
    // Sub-tiles starting at or beyond num_tokens store nothing.
}
1769
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32 (same 32×32 tiling,
// 256 threads, MMA32_* tile constants and output layout).
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory, halving the A/B tile footprint (2 × 2 KB = 4 KB vs 8 KB in
// matmul_q8_mma32; the float scratch used by the partial-token store path
// is the same size in both). The smaller footprint may allow more
// concurrent threadgroups per GPU core on Apple Silicon
// (NOTE(review): occupancy gain not measured here).
//
// Precision: the Q8_0 block scale is itself read as `half`, and each
// dequantized weight int8 × scale is rounded to half (11-bit significand),
// so results can differ slightly from matmul_q8_mma32. Token activations
// are narrowed f32 → f16 too; any value beyond half's range (±65504) would
// become ±inf. The accumulators C_a / C_b are float simdgroup matrices.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-adjacent 8×8
// accumulators each. Intended win vs mma32 is occupancy at moderate
// prefill lengths where the GPU is wave-starved.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup, so no thread is left waiting at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx = simd_id / 2;  // 0..3
    uint sg_row_half = simd_id % 2; // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16 (computed in float, then rounded).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); slots past
        // num_tokens are zero-filled.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Half A/B fragments, float accumulators; A reused for both row halves.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // All reads of the tiles done before the next block overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store: rows is assumed MMA32_ROW_TILE-aligned, so only the token
    // dimension can be partial (same scheme as matmul_q8_mma32).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Per-simdgroup 128-float scratch slice (C_a at +0, C_b at +64);
        // lane 0 copies only the valid token rows.
        threadgroup float scratch[8 * 2 * 64];
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        // Order the scratch writes before lane 0 reads them.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok; // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
    // Sub-tiles starting at or beyond num_tokens store nothing.
}
1898
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel. Tile loads,
// half-precision tiles and float accumulators are the same as in
// matmul_q8_mma32_h, so the same f16 rounding caveats apply.
//
// Instead of 8 simdgroups × 2 row-adjacent accumulators, this kernel runs
// 4 simdgroups × a **2×2 grid** of 8×8 accumulators each. simd_id selects
// one 16×16 quadrant of the 32×32 output tile (token quadrant = simd_id / 2,
// row quadrant = simd_id % 2). Within that quadrant:
//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions (A_top and A_bot each paired with B_lo and B_hi).
// That is 4 MMAs per 4 fragment loads, versus 2 per 3 in the 2-accumulator
// kernel, so about 1.5× the MMA-per-load ratio. It also uses 128 threads per
// threadgroup instead of 256. The cooperative loads give each thread 8
// consecutive elements of a 1024-element tile, so exactly 128 are required.
// Fewer threads may improve occupancy on Apple GPUs when the
// concurrent-thread budget, rather than shared memory, is the tighter limit
// (NOTE(review): not measured here).
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup, so no thread is left waiting at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2; // 0..1
    uint sg_row_q = simd_id % 2; // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16); slots past num_tokens are zero.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // All reads of the tiles done before the next block overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles. The fast path requires all 16 token rows of the
    // quadrant to be valid; otherwise everything goes through scratch.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch: a 256-float slice
        // at simd_id * 256 holding C_00, C_01, C_10, C_11 at +0/+64/+128/+192.
        // Lane 0 copies each token row only if it is < num_tokens, so a
        // quadrant lying entirely past the end writes nothing.
        threadgroup float scratch[4 * 4 * 64]; // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base + 0, 8);
        simdgroup_store(C_01, scratch + sg_base + 64, 8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        // Order the scratch writes before lane 0 reads them.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch + sg_base + 0 + m * 8;
                    threadgroup const float* src1 = scratch + sg_base + 64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2053
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4: same 32×32 tiling, 128
// threads and 2×2 accumulator grid per simdgroup, but the accumulators are
// simdgroup_matrix<half, 8, 8> as well as the A/B fragments.
//
// Motivation: FP16 `simdgroup_multiply_accumulate` may run faster than FP32
// on Apple Silicon. NOTE(review): the original author expected roughly 2×
// the FP32 rate (dual-issue FMA); this was not measured here. If it holds,
// matmul-bound prefill throughput could roughly double.
//
// Numerical notes: every partial sum over K is held in half (11-bit
// significand, max ±65504). Unlike the float-accumulator kernels, the
// rounding error grows with K (e.g. K = 2048 for a 1B hidden dimension,
// 8192 for an FFN down projection), and a large enough dot product
// overflows to ±inf. The author's argument for tolerating this is that
// Q8_0 weights carry only ~8 bits of precision and post-RMSNorm
// activations are ≈ O(1). Author's note: correctness is to be verified on
// 135M / 1B / 3B models before enabling.
//
// Output: there is no direct simdgroup_store fast path. All four half
// accumulators are staged through threadgroup scratch and widened to float
// by lane 0 of each simdgroup, with per-token bounds checks.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup, so no thread is left waiting at a barrier.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 2 KB each.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 2×2 grid of 16×16 quadrants, one per simdgroup.
    uint sg_tok_q = simd_id / 2; // 0..1
    uint sg_row_q = simd_id % 2; // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant to FP16 — 128 threads × 8 = 1024 = 32*32.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        // Cooperative token tile load (f32 → f16); slots past num_tokens are zero.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        // Tiles fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 half MMAs each (same pairing as matmul_q8_mma32_h4).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // All reads of the tiles done before the next block overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;

    // Per-simdgroup 256-half slice at simd_id * 256 holding C_00, C_01,
    // C_10, C_11 at +0/+64/+128/+192 (2 KB total for the threadgroup).
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base + 0, 8);
    simdgroup_store(C_01, scratch + sg_base + 64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    // Order the scratch writes before lane 0 reads them.
    simdgroup_barrier(mem_flags::mem_threadgroup);
    uint lane = tid % 32;
    if (lane == 0) {
        // Copy only token rows < num_tokens, widening half → float.
        for (uint m = 0; m < 8; m++) {
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base + 0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst0[j] = float(src0[j]);
                    dst1[j] = float(src1[j]);
                }
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst2[j] = float(src2[j]);
                    dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2191
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add of a per-row bias over a row-major
// [num_tokens, rows] matrix:
//   out[token * rows + i] += bias[i]   for token < num_tokens, i < rows
// One thread per element. Threads whose flat index is at or past
// num_tokens * rows do nothing.
// Used for Qwen2 QKV bias after the fused qkv matmul.
kernel void add_bias_batch(
    device float* out [[buffer(0)]],         // [num_tokens, rows], updated in place
    device const float* bias [[buffer(1)]],  // [rows]
    constant uint& num_tokens [[buffer(2)]],
    constant uint& rows [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    uint element_count = num_tokens * rows;
    if (id < element_count) {
        // Column within the row-major layout selects the bias entry.
        uint col = id % rows;
        out[id] = out[id] + bias[col];
    }
}
2208
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors:
//   outputs[token * rows + row] = sum_k dequant(matrix)[row, k] * inputs[token * cols + k]
//
// Q4_0 layout as read below: each row is cols/32 blocks of 18 bytes, a
// 2-byte half scale followed by 16 bytes of packed nibbles. Byte j holds
// element j in its low nibble and element j + 16 in its high nibble. Each
// nibble q decodes as (q - 8) * scale.
//
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups. Threadgroup tgid handles
// token tgid / row_tgs and row tile tgid % row_tgs. Each simdgroup computes
// four consecutive rows starting at row_base. Its lanes stride over the K
// blocks and the per-lane partial sums are reduced with simd_sum.
//
// Assumptions not checked here (presumably enforced by the dispatcher):
//   * 256 threads per threadgroup (the input-tile copy strides by 256);
//   * cols % 32 == 0 and cols <= VEC_TILE_SIZE;
//   * Q4_ROWS_PER_SG == 4, matching the four row accumulators below.
//
// rows need not be a multiple of 4. Row pointers past the last row are
// clamped to the last valid row so that every load stays inside the
// matrix. The resulting duplicate sums are discarded by the guarded stores.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes [rows, cols/32 * 18]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]],      // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]],  // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so returning before the barrier is safe.
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory; every
    // simdgroup in the threadgroup reads it.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    // Uniform per simdgroup and after the only barrier, so this is safe.
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // The final group of rows may be partial when rows % 4 != 0. Clamp its
    // row pointers to the last row so the block loop never reads past the
    // end of the matrix. For in-range rows the min() is a no-op.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Each lane handles every 32nd block; per block it dots the 32 quants of
    // all four rows against the same 32 input values.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18; // byte offset of this block within a row
        uint vb = blk * 32; // element offset of this block within the input

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // 16 nibble bytes per block, viewed as 4 × uchar4.
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 hold block elements 0..15 (paired with low nibbles),
        // v4..v7 hold elements 16..31 (paired with high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        // Unrolled: p[n] covers nibble bytes 4n..4n+3. Low nibbles pair with
        // v[n] and high nibbles with v[n+4]; each nibble decodes as (q - 8).
        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Reduce the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Guarded stores: rows past the end (clamped duplicates above) are dropped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows)     output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2309
2310// ── copy_kv_batch ─────────────────────────────────────────────────────
2311// Copy K or V from a strided batch QKV buffer to the KV cache.
2312// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
2313// dst layout: contiguous [max_seq, kv_dim] cache.
2314kernel void copy_kv_batch(
2315 device const float* src [[buffer(0)]], // batch QKV buffer (f32)
2316 device half* dst [[buffer(1)]], // KV cache (f16)
2317 constant uint& M [[buffer(2)]], // num tokens in batch
2318 constant uint& kv_dim [[buffer(3)]], // elements per KV vector
2319 constant uint& base_pos [[buffer(4)]], // starting position in cache
2320 constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
2321 constant uint& src_offset [[buffer(6)]], // float offset within each src row
2322 uint id [[thread_position_in_grid]])
2323{
2324 uint total = M * kv_dim;
2325 if (id >= total) return;
2326 uint token = id / kv_dim;
2327 uint d = id % kv_dim;
2328 uint dst_off = (base_pos + token) * kv_dim + d;
2329 uint src_off = token * src_stride + src_offset + d;
2330 dst[dst_off] = half(src[src_off]);
2331}
2332
2333// ── attention_batch ───────────────────────────────────────────────────
2334// Batched causal attention for prefill. Processes M tokens in one dispatch.
2335// Each threadgroup handles one (token, head) pair.
2336// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
2337// Causal masking: token i can only attend to positions 0..base_pos+i.
2338kernel void attention_batch(
2339 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2340 device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
2341 device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
2342 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2343 constant uint& M [[buffer(4)]], // num tokens in batch
2344 constant uint& base_pos [[buffer(5)]], // starting position in KV cache
2345 constant uint& num_heads [[buffer(6)]],
2346 constant uint& num_kv_heads [[buffer(7)]],
2347 constant uint& head_dim [[buffer(8)]],
2348 constant uint& q_stride [[buffer(9)]], // floats per row in q_batch
2349 uint tgid [[threadgroup_position_in_grid]],
2350 uint tid [[thread_index_in_threadgroup]],
2351 uint simd_lane [[thread_index_in_simdgroup]],
2352 uint simd_id [[simdgroup_index_in_threadgroup]])
2353{
2354 // Grid: M * num_heads threadgroups
2355 uint token_idx = tgid / num_heads;
2356 uint head = tgid % num_heads;
2357 if (token_idx >= M) return;
2358
2359 uint kv_head = head / (num_heads / num_kv_heads);
2360 uint seq_len = base_pos + token_idx + 1; // causal: see positions 0..base_pos+token_idx
2361
2362 // Q offset uses strided layout (from batch QKV buffer)
2363 uint q_off = token_idx * q_stride + head * head_dim;
2364 // Output is contiguous [M, num_heads * head_dim]
2365 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2366
2367 // Shared memory for attention scores — sized to the effective max_seq_len
2368 // (4096 for all supported models) so long-context attention doesn't overflow.
2369 threadgroup float scores[ATTN_SCORES_SIZE];
2370
2371 // Step 1: Q * K^T with simdgroup reduction
2372 for (uint s = simd_id; s < seq_len; s += 8) {
2373 uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
2374 float dot = 0.0;
2375 for (uint d = simd_lane; d < head_dim; d += 32) {
2376 dot += q_batch[q_off + d] * k_cache[k_off + d];
2377 }
2378 dot = simd_sum(dot);
2379 if (simd_lane == 0) {
2380 scores[s] = dot * fast::rsqrt(float(head_dim));
2381 }
2382 }
2383 threadgroup_barrier(mem_flags::mem_threadgroup);
2384
2385 // Step 2: Softmax (cooperative)
2386 float local_max = -INFINITY;
2387 for (uint s = tid; s < seq_len; s += 256) {
2388 local_max = max(local_max, scores[s]);
2389 }
2390 local_max = simd_max(local_max);
2391 threadgroup float shared_max[8];
2392 if (simd_lane == 0) shared_max[simd_id] = local_max;
2393 threadgroup_barrier(mem_flags::mem_threadgroup);
2394 if (tid == 0) {
2395 float m = shared_max[0];
2396 for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
2397 shared_max[0] = m;
2398 }
2399 threadgroup_barrier(mem_flags::mem_threadgroup);
2400 float max_val = shared_max[0];
2401
2402 float local_sum = 0.0;
2403 for (uint s = tid; s < seq_len; s += 256) {
2404 scores[s] = fast::exp(scores[s] - max_val);
2405 local_sum += scores[s];
2406 }
2407 local_sum = simd_sum(local_sum);
2408 threadgroup float shared_sum[8];
2409 if (simd_lane == 0) shared_sum[simd_id] = local_sum;
2410 threadgroup_barrier(mem_flags::mem_threadgroup);
2411 if (tid == 0) {
2412 float total = 0.0;
2413 for (uint i = 0; i < 8; i++) total += shared_sum[i];
2414 shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
2415 }
2416 threadgroup_barrier(mem_flags::mem_threadgroup);
2417 float inv_sum = shared_sum[0];
2418 for (uint s = tid; s < seq_len; s += 256) {
2419 scores[s] *= inv_sum;
2420 }
2421 threadgroup_barrier(mem_flags::mem_threadgroup);
2422
2423 // Step 3: scores * V using float4 vectorized loads
2424 // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
2425 // This is much better than the scalar version where only 64 of 256 threads are active.
2426 uint v_stride = num_kv_heads * head_dim;
2427 uint head_dim4 = head_dim / 4;
2428 for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
2429 uint d = d4 * 4;
2430 float4 acc = float4(0.0);
2431 uint v_base = kv_head * head_dim + d;
2432 uint seq_len4 = seq_len & ~3u;
2433 for (uint s = 0; s < seq_len4; s += 4) {
2434 float sc0 = scores[s];
2435 float sc1 = scores[s + 1];
2436 float sc2 = scores[s + 2];
2437 float sc3 = scores[s + 3];
2438 acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
2439 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
2440 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
2441 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
2442 }
2443 for (uint s = seq_len4; s < seq_len; s++) {
2444 acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
2445 }
2446 *(device float4*)(output_batch + out_off + d) = acc;
2447 }
2448 // Handle remaining dimensions not divisible by 4 (scalar fallback)
2449 for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
2450 float acc = 0.0;
2451 uint v_base = kv_head * head_dim + d;
2452 for (uint s = 0; s < seq_len; s++) {
2453 acc += scores[s] * v_cache[s * v_stride + v_base];
2454 }
2455 output_batch[out_off + d] = acc;
2456 }
2457}
2458
2459// ── attention_flash_batch ─────────────────────────────────────────────
2460// Streaming attention with online softmax. Same grid as attention_batch
2461// (M × num_heads threadgroups, one per (token, head) pair) but the scores
2462// matrix is never materialized. K/V positions are processed in a tile of
2463// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
2464// the standard flash-attention recurrence:
2465//
2466// m_new = max(m_old, tile_max)
2467// alpha = exp(m_old - m_new)
2468// l_new = alpha * l_old + sum(exp(S - m_new))
2469// O_new = alpha * O_old + sum(exp(S - m_new) * V)
2470// O_final = O / l
2471//
2472// This removes the `scores[2048]` cap in attention_batch (which silently
2473// overflows for prompts with seq_len > 2048) and keeps threadgroup memory
2474// use to O(head_dim + FLASH_K_TILE) instead of O(seq_len).
2475//
2476// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
2477constant constexpr uint FLASH_K_TILE = 32;
2478constant constexpr uint FLASH_MAX_HEAD_DIM = 256;
2479
2480kernel void attention_flash_batch(
2481 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2482 device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
2483 device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
2484 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2485 constant uint& M [[buffer(4)]],
2486 constant uint& base_pos [[buffer(5)]],
2487 constant uint& num_heads [[buffer(6)]],
2488 constant uint& num_kv_heads [[buffer(7)]],
2489 constant uint& head_dim [[buffer(8)]],
2490 constant uint& q_stride [[buffer(9)]],
2491 uint tgid [[threadgroup_position_in_grid]],
2492 uint tid [[thread_index_in_threadgroup]],
2493 uint simd_lane [[thread_index_in_simdgroup]],
2494 uint simd_id [[simdgroup_index_in_threadgroup]])
2495{
2496 uint token_idx = tgid / num_heads;
2497 uint head = tgid % num_heads;
2498 if (token_idx >= M) return;
2499
2500 uint kv_head = head / (num_heads / num_kv_heads);
2501 uint seq_len = base_pos + token_idx + 1; // causal: attend to [0, base_pos + token_idx]
2502 uint q_off = token_idx * q_stride + head * head_dim;
2503 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2504
2505 // Threadgroup state:
2506 // q_sh: Q vector for this (token, head), loaded once
2507 // o_sh: running output vector, updated each K tile
2508 // scores_sh: scores for the current K tile only
2509 // stats: [running max, running sum] (see flash-attention recurrence)
2510 threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
2511 threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
2512 threadgroup float scores_sh[FLASH_K_TILE];
2513 threadgroup float stats[2];
2514 threadgroup float sg_scratch[8]; // simdgroup-level reduction buffer
2515
2516 // --- Load Q (one row) and zero the running O ---
2517 for (uint d = tid; d < head_dim; d += 256) {
2518 q_sh[d] = q_batch[q_off + d];
2519 o_sh[d] = 0.0f;
2520 }
2521 if (tid == 0) {
2522 stats[0] = -INFINITY;
2523 stats[1] = 0.0f;
2524 }
2525 threadgroup_barrier(mem_flags::mem_threadgroup);
2526
2527 float scale = fast::rsqrt(float(head_dim));
2528 uint v_stride = num_kv_heads * head_dim;
2529 uint v_base = kv_head * head_dim;
2530
2531 // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
2532 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
2533 uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);
2534
2535 // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
2536 // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
2537 for (uint ti = simd_id; ti < tile_n; ti += 8) {
2538 uint k_off = (kv_base + ti) * v_stride + v_base; // same layout as V stride
2539 float dot = 0.0f;
2540 for (uint d = simd_lane; d < head_dim; d += 32) {
2541 dot += q_sh[d] * k_cache[k_off + d];
2542 }
2543 dot = simd_sum(dot);
2544 if (simd_lane == 0) {
2545 scores_sh[ti] = dot * scale;
2546 }
2547 }
2548 threadgroup_barrier(mem_flags::mem_threadgroup);
2549
2550 // [2] Tile max via cooperative reduction.
2551 float local_max = -INFINITY;
2552 for (uint s = tid; s < tile_n; s += 256) {
2553 local_max = max(local_max, scores_sh[s]);
2554 }
2555 local_max = simd_max(local_max);
2556 if (simd_lane == 0) {
2557 sg_scratch[simd_id] = local_max;
2558 }
2559 threadgroup_barrier(mem_flags::mem_threadgroup);
2560 // [3] Merge with running max, compute alpha, rescale running l.
2561 float m_new;
2562 float alpha;
2563 if (tid == 0) {
2564 float tile_max = sg_scratch[0];
2565 for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
2566 float m_old = stats[0];
2567 m_new = max(m_old, tile_max);
2568 // First iteration: m_old = -inf → alpha = 0 (reset O).
2569 alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2570 stats[0] = m_new;
2571 stats[1] *= alpha;
2572 // Broadcast via sg_scratch.
2573 sg_scratch[0] = alpha;
2574 sg_scratch[1] = m_new;
2575 }
2576 threadgroup_barrier(mem_flags::mem_threadgroup);
2577 alpha = sg_scratch[0];
2578 m_new = sg_scratch[1];
2579
2580 // [4] Rescale running output by alpha, then compute exp(scores - m_new).
2581 for (uint d = tid; d < head_dim; d += 256) {
2582 o_sh[d] *= alpha;
2583 }
2584 for (uint s = tid; s < tile_n; s += 256) {
2585 scores_sh[s] = fast::exp(scores_sh[s] - m_new);
2586 }
2587 threadgroup_barrier(mem_flags::mem_threadgroup);
2588
2589 // [5] Tile sum → update running l.
2590 float local_sum = 0.0f;
2591 for (uint s = tid; s < tile_n; s += 256) {
2592 local_sum += scores_sh[s];
2593 }
2594 local_sum = simd_sum(local_sum);
2595 if (simd_lane == 0) {
2596 sg_scratch[simd_id] = local_sum;
2597 }
2598 threadgroup_barrier(mem_flags::mem_threadgroup);
2599 if (tid == 0) {
2600 float tile_sum = 0.0f;
2601 for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
2602 stats[1] += tile_sum;
2603 }
2604 threadgroup_barrier(mem_flags::mem_threadgroup);
2605
2606 // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
2607 for (uint d = tid; d < head_dim; d += 256) {
2608 float acc = 0.0f;
2609 for (uint s = 0; s < tile_n; s++) {
2610 acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
2611 }
2612 o_sh[d] += acc;
2613 }
2614 threadgroup_barrier(mem_flags::mem_threadgroup);
2615 }
2616
2617 // --- Normalize and write output ---
2618 float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
2619 for (uint d = tid; d < head_dim; d += 256) {
2620 output_batch[out_off + d] = o_sh[d] * inv_l;
2621 }
2622}
2623
2624// ── attention_mma_flash_batch ─────────────────────────────────────────
2625// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
2626// both Q·K^T and P·V. Processes Q_BLOCK=8 tokens of one head per
2627// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
2628// K/V loads across 8 Q rows and using hardware matrix-multiply for the
2629// arithmetic.
2630//
2631// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
2632// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
2633// to attention_batch / attention_flash_batch otherwise.
2634//
2635// Online softmax recurrence is identical to attention_flash_batch but
2636// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
2637constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
2638constant constexpr uint FLASH_MMA_K_BLOCK = 32;
2639constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;
2640
2641kernel void attention_mma_flash_batch(
2642 device const float* q_batch [[buffer(0)]],
2643 device const half* k_cache [[buffer(1)]],
2644 device const half* v_cache [[buffer(2)]],
2645 device float* output_batch [[buffer(3)]],
2646 constant uint& M [[buffer(4)]],
2647 constant uint& base_pos [[buffer(5)]],
2648 constant uint& num_heads [[buffer(6)]],
2649 constant uint& num_kv_heads [[buffer(7)]],
2650 constant uint& head_dim [[buffer(8)]],
2651 constant uint& q_stride [[buffer(9)]],
2652 uint2 tgid [[threadgroup_position_in_grid]],
2653 uint tid [[thread_index_in_threadgroup]],
2654 uint simd_lane [[thread_index_in_simdgroup]],
2655 uint simd_id [[simdgroup_index_in_threadgroup]])
2656{
2657 uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2658 uint head = tgid.y;
2659 if (q_block_start >= M) return;
2660 uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2661
2662 uint kv_head = head / (num_heads / num_kv_heads);
2663 // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2664 // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2665 uint seq_len = base_pos + q_block_start + q_valid;
2666 float scale = fast::rsqrt(float(head_dim));
2667
2668 uint kv_stride = num_kv_heads * head_dim;
2669 uint kv_base_off = kv_head * head_dim;
2670
2671 // ── Threadgroup memory ──
2672 // q_sh: [Q_BLOCK, head_dim] half — Q tile, loaded once
2673 // k_sh: [K_BLOCK, head_dim] half — K tile, refreshed per kv_base iter
2674 // v_sh: [K_BLOCK, head_dim] half — V tile, refreshed per kv_base iter
2675 // s_sh: [Q_BLOCK, K_BLOCK] float — raw Q·K^T scores, then scaled+masked
2676 // p_sh: [Q_BLOCK, K_BLOCK] half — softmax probabilities (for P·V MMA)
2677 // o_sh: [Q_BLOCK, head_dim] float — running output accumulator
2678 // m_sh: [Q_BLOCK] float — running max per Q row
2679 // l_sh: [Q_BLOCK] float — running softmax denominator per Q row
2680 // scratch: 4*Q_BLOCK floats — per-row reduction scratch
2681 threadgroup half q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2682 threadgroup half k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2683 threadgroup half v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2684 threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2685 threadgroup half p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2686 threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2687 threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2688 threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2689 threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2690
2691 // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2692 uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2693 for (uint i = tid; i < qblock_elems; i += 128) {
2694 uint q = i / head_dim;
2695 uint d = i % head_dim;
2696 if (q < q_valid) {
2697 uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2698 q_sh[q * head_dim + d] = half(q_batch[q_off]);
2699 } else {
2700 q_sh[q * head_dim + d] = half(0);
2701 }
2702 o_sh[q * head_dim + d] = 0.0f;
2703 }
2704 if (tid < FLASH_MMA_Q_BLOCK) {
2705 m_sh[tid] = -INFINITY;
2706 l_sh[tid] = 0.0f;
2707 }
2708 threadgroup_barrier(mem_flags::mem_threadgroup);
2709
2710 // ── Stream K/V in K_BLOCK chunks ──
2711 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2712 uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2713
2714 // Load K and V tile into TG memory (as half).
2715 uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2716 for (uint i = tid; i < kv_tile_elems; i += 128) {
2717 uint k_pos = i / head_dim;
2718 uint d = i % head_dim;
2719 if (k_pos < tile_n) {
2720 uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2721 k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2722 v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2723 } else {
2724 k_sh[k_pos * head_dim + d] = half(0);
2725 v_sh[k_pos * head_dim + d] = half(0);
2726 }
2727 }
2728 threadgroup_barrier(mem_flags::mem_threadgroup);
2729
2730 // ── Phase 1: S = Q @ K^T via MMA ──
2731 // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2732 // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2733 // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2734 {
2735 simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2736 uint dim_chunks = head_dim / 8;
2737 for (uint dc = 0; dc < dim_chunks; dc++) {
2738 simdgroup_matrix<half, 8, 8> A, B;
2739 // A = Q[0:8, dc*8 : dc*8+8] (rows of Q, no transpose)
2740 simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2741 // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2742 // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2743 // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2744 // transpose=true, which places it in the register as K^T of that sub-block.
2745 simdgroup_load(B,
2746 k_sh + (simd_id * 8) * head_dim + dc * 8,
2747 head_dim, ulong2(0, 0), true);
2748 simdgroup_multiply_accumulate(C, A, B, C);
2749 }
2750 // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2751 simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2752 }
2753 threadgroup_barrier(mem_flags::mem_threadgroup);
2754
2755 // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2756 // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2757 uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2758 for (uint i = tid; i < s_elems; i += 128) {
2759 uint q = i / FLASH_MMA_K_BLOCK;
2760 uint k = i % FLASH_MMA_K_BLOCK;
2761 uint global_q = q_block_start + q;
2762 uint global_kv = kv_base + k;
2763 bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2764 s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2765 }
2766 threadgroup_barrier(mem_flags::mem_threadgroup);
2767
2768 // ── Phase 2b: per-row max via simdgroup reduction ──
2769 // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2770 // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2771 {
2772 uint row_base = simd_id * 2;
2773 for (uint qr = 0; qr < 2; qr++) {
2774 uint q = row_base + qr;
2775 float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2776 float row_max = simd_max(my);
2777 if (simd_lane == 0) {
2778 scratch[q] = row_max; // tile_max[q]
2779 }
2780 }
2781 }
2782 threadgroup_barrier(mem_flags::mem_threadgroup);
2783
2784 // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2785 if (tid < FLASH_MMA_Q_BLOCK) {
2786 uint q = tid;
2787 float m_old = m_sh[q];
2788 float tile_max = scratch[q];
2789 float m_new = max(m_old, tile_max);
2790 float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2791 m_sh[q] = m_new;
2792 l_sh[q] = l_sh[q] * alpha;
2793 // scratch[q] = m_new (for phase 2d)
2794 // scratch[Q_BLOCK + q] = alpha (for phase 3)
2795 scratch[q] = m_new;
2796 scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2797 }
2798 threadgroup_barrier(mem_flags::mem_threadgroup);
2799
2800 // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2801 {
2802 uint row_base = simd_id * 2;
2803 for (uint qr = 0; qr < 2; qr++) {
2804 uint q = row_base + qr;
2805 float m_new = scratch[q];
2806 float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2807 p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2808 float row_sum = simd_sum(p);
2809 if (simd_lane == 0) {
2810 scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2811 }
2812 }
2813 }
2814 threadgroup_barrier(mem_flags::mem_threadgroup);
2815
2816 // ── Phase 2e: l_sh += tile_sum ──
2817 if (tid < FLASH_MMA_Q_BLOCK) {
2818 uint q = tid;
2819 l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2820 }
2821
2822 // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2823 threadgroup_barrier(mem_flags::mem_threadgroup);
2824 for (uint i = tid; i < qblock_elems; i += 128) {
2825 uint q = i / head_dim;
2826 float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2827 o_sh[i] *= alpha;
2828 }
2829 threadgroup_barrier(mem_flags::mem_threadgroup);
2830
2831 // ── Phase 4: O += P @ V via MMA ──
2832 // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2833 // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2834 // For head_dim=64: 16 dims/sg = 2 tiles. head_dim=128: 32 dims/sg = 4 tiles.
2835 {
2836 uint dims_per_sg = head_dim / 4; // 16 or 32
2837 uint tiles_per_sg = dims_per_sg / 8; // 2 or 4
2838 uint sg_d_base = simd_id * dims_per_sg;
2839 for (uint t = 0; t < tiles_per_sg; t++) {
2840 uint d_base = sg_d_base + t * 8;
2841 simdgroup_matrix<float, 8, 8> O_acc;
2842 simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2843 uint k_chunks = FLASH_MMA_K_BLOCK / 8; // 4
2844 for (uint kc = 0; kc < k_chunks; kc++) {
2845 simdgroup_matrix<half, 8, 8> A, B;
2846 simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2847 ulong2(0, 0), false);
2848 simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2849 ulong2(0, 0), false);
2850 simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2851 }
2852 simdgroup_store(O_acc, o_sh + d_base, head_dim);
2853 }
2854 }
2855 threadgroup_barrier(mem_flags::mem_threadgroup);
2856 }
2857
2858 // ── Finalize: O /= l, write to output ──
2859 for (uint i = tid; i < qblock_elems; i += 128) {
2860 uint q = i / head_dim;
2861 uint d = i % head_dim;
2862 if (q < q_valid) {
2863 float l = l_sh[q];
2864 float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2865 uint token_idx = q_block_start + q;
2866 uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2867 output_batch[out_off] = o_sh[i] * inv_l;
2868 }
2869 }
2870}
2871
2872// ── rope_qk_batch ─────────────────────────────────────────────────────
2873// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2874// launch + memory barrier per layer. Both Q and K live in the same
2875// qkv_data buffer at different offsets within each token's row.
2876// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2877kernel void rope_qk_batch(
2878 device float* qkv_data [[buffer(0)]], // [M, qkv_stride]
2879 constant uint& M [[buffer(1)]], // num tokens
2880 constant uint& base_pos [[buffer(2)]], // starting position
2881 constant uint& num_q_heads [[buffer(3)]],
2882 constant uint& num_kv_heads [[buffer(4)]],
2883 constant uint& head_dim [[buffer(5)]],
2884 constant uint& qkv_stride [[buffer(6)]], // floats per row
2885 constant float& theta [[buffer(7)]],
2886 uint id [[thread_position_in_grid]])
2887{
2888 uint half_dim = head_dim / 2;
2889 uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2890 uint token = id / total_pairs;
2891 uint pair = id % total_pairs;
2892 if (token >= M) return;
2893
2894 uint pos = base_pos + token;
2895 uint q_pairs = num_q_heads * half_dim;
2896
2897 uint h, i, offset;
2898 if (pair < q_pairs) {
2899 // Q head
2900 h = pair / half_dim;
2901 i = pair % half_dim;
2902 offset = token * qkv_stride + h * head_dim + i * 2;
2903 } else {
2904 // K head
2905 uint kp = pair - q_pairs;
2906 h = kp / half_dim;
2907 i = kp % half_dim;
2908 uint k_start = num_q_heads * head_dim;
2909 offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2910 }
2911
2912 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2913 float angle = float(pos) * freq;
2914 float cos_val = cos(angle);
2915 float sin_val = sin(angle);
2916
2917 float x0 = qkv_data[offset];
2918 float x1 = qkv_data[offset + 1];
2919 qkv_data[offset] = x0 * cos_val - x1 * sin_val;
2920 qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2921}
2922
2923// ── copy_kv_both_batch ────────────────────────────────────────────────
2924// Fused K+V cache copy in a single dispatch: copies both K and V from
2925// the strided batch QKV buffer to their respective KV cache buffers.
2926// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2927kernel void copy_kv_both_batch(
2928 device const float* src [[buffer(0)]], // batch QKV buffer [M, qkv_stride] f32
2929 device half* k_dst [[buffer(1)]], // K cache [max_seq, kv_dim] f16
2930 device half* v_dst [[buffer(2)]], // V cache [max_seq, kv_dim] f16
2931 constant uint& M [[buffer(3)]], // num tokens in batch
2932 constant uint& kv_dim [[buffer(4)]], // elements per KV vector
2933 constant uint& base_pos [[buffer(5)]], // starting position in cache
2934 constant uint& src_stride [[buffer(6)]], // floats per row in src (qkv_stride)
2935 constant uint& k_offset [[buffer(7)]], // float offset of K within each src row
2936 constant uint& v_offset [[buffer(8)]], // float offset of V within each src row
2937 uint id [[thread_position_in_grid]])
2938{
2939 // Total elements = M * kv_dim * 2 (K + V)
2940 uint total_kv = M * kv_dim;
2941 if (id >= total_kv * 2) return;
2942
2943 uint is_v = id / total_kv; // 0 = K, 1 = V
2944 uint local_id = id % total_kv;
2945 uint token = local_id / kv_dim;
2946 uint d = local_id % kv_dim;
2947
2948 uint dst_off = (base_pos + token) * kv_dim + d;
2949 uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2950
2951 if (is_v) {
2952 v_dst[dst_off] = half(src[src_off]);
2953 } else {
2954 k_dst[dst_off] = half(src[src_off]);
2955 }
2956}
2957"#
2958 .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2959 .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2960}
2961
2962fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2967 let mut code = String::with_capacity(48 * 1024);
2968 emit_model_header(&mut code, config)?;
2969 emit_metal_model_struct(&mut code, config)?;
2970 emit_layer_buffers_struct(&mut code, config)?;
2971 emit_metal_model_impl(&mut code, config)?;
2972 emit_helper_functions(&mut code)?;
2973 Ok(code)
2974}
2975
2976fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2977 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2978 writeln!(
2979 code,
2980 "//! Model: {} ({} layers, hidden={})",
2981 config.architecture, config.num_layers, config.hidden_size
2982 )?;
2983 writeln!(code, "//!")?;
2984 writeln!(
2985 code,
2986 "//! Uses native Metal compute pipelines via the metal crate."
2987 )?;
2988 writeln!(code)?;
2989 writeln!(code, "#![allow(dead_code)]")?;
2990 writeln!(code)?;
2991 writeln!(code, "use metal::*;")?;
2992 writeln!(code, "#[allow(unused_imports)]")?;
2993 writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2994 writeln!(code, "use std::mem;")?;
2995 writeln!(code)?;
2996
2997 writeln!(
2999 code,
3000 "// ── Model constants ──────────────────────────────────"
3001 )?;
3002 writeln!(
3003 code,
3004 "pub const HIDDEN_SIZE: usize = {};",
3005 config.hidden_size
3006 )?;
3007 writeln!(
3008 code,
3009 "pub const INTERMEDIATE_SIZE: usize = {};",
3010 config.intermediate_size
3011 )?;
3012 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
3013 writeln!(
3014 code,
3015 "pub const NUM_HEADS: usize = {};",
3016 config.num_attention_heads
3017 )?;
3018 writeln!(
3019 code,
3020 "pub const NUM_KV_HEADS: usize = {};",
3021 config.num_kv_heads
3022 )?;
3023 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3024 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3025 let effective_seq_len = config.max_seq_len.min(4096);
3026 writeln!(
3027 code,
3028 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
3029 effective_seq_len, config.max_seq_len
3030 )?;
3031 writeln!(
3032 code,
3033 "pub const RMS_NORM_EPS: f32 = {:e};",
3034 config.rms_norm_eps
3035 )?;
3036 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3037 writeln!(
3038 code,
3039 "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3040 )?;
3041 writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3042 writeln!(code)?;
3043
3044 Ok(())
3045}
3046
3047fn emit_metal_model_struct(
3048 code: &mut String,
3049 config: &ModelConfig,
3050) -> Result<(), MetalCodegenError> {
3051 writeln!(
3052 code,
3053 "// ── MetalModel ──────────────────────────────────────────"
3054 )?;
3055 writeln!(code)?;
3056 writeln!(
3057 code,
3058 "/// Metal-accelerated transformer model for Apple Silicon."
3059 )?;
3060 writeln!(code, "///")?;
3061 writeln!(
3062 code,
3063 "/// Uses unified memory for zero-copy weight access and native Metal"
3064 )?;
3065 writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3066 writeln!(code, "pub struct MetalModel {{")?;
3067 writeln!(code, " device: Device,")?;
3068 writeln!(code, " queue: CommandQueue,")?;
3069 writeln!(code)?;
3070 writeln!(code, " // ── Compute pipelines ──")?;
3071 writeln!(code, " matmul_pipeline: ComputePipelineState,")?;
3072 writeln!(code, " matmul_q8_pipeline: ComputePipelineState,")?;
3073 writeln!(code, " matmul_q4_pipeline: ComputePipelineState,")?;
3074 writeln!(code, " rms_norm_pipeline: ComputePipelineState,")?;
3075 writeln!(code, " rope_pipeline: ComputePipelineState,")?;
3076 writeln!(code, " softmax_pipeline: ComputePipelineState,")?;
3077 writeln!(code, " silu_mul_pipeline: ComputePipelineState,")?;
3078 writeln!(code, " silu_mul_fused_pipeline: ComputePipelineState,")?;
3079 writeln!(code, " add_pipeline: ComputePipelineState,")?;
3080 writeln!(code, " attention_pipeline: ComputePipelineState,")?;
3081 writeln!(code, " add_inplace_pipeline: ComputePipelineState,")?;
3082 writeln!(code, " copy_pipeline: ComputePipelineState,")?;
3083 writeln!(code, " copy_offset_pipeline: ComputePipelineState,")?;
3084 writeln!(
3085 code,
3086 " copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3087 )?;
3088 writeln!(code)?;
3089 writeln!(code, " // ── Batched prefill pipelines ──")?;
3090 writeln!(code, " matmul_batch_pipeline: ComputePipelineState,")?;
3091 writeln!(code, " matmul_q8_batch_pipeline: ComputePipelineState,")?;
3092 writeln!(
3093 code,
3094 " matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3095 )?;
3096 writeln!(code, " matmul_q8_mma_pipeline: ComputePipelineState,")?;
3097 writeln!(code, " matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3098 writeln!(
3099 code,
3100 " matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3101 )?;
3102 writeln!(
3103 code,
3104 " matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3105 )?;
3106 writeln!(
3107 code,
3108 " matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3109 )?;
3110 if config.qkv_bias {
3111 writeln!(code, " add_bias_batch_pipeline: ComputePipelineState,")?;
3112 }
3113 writeln!(code, " matmul_q4_batch_pipeline: ComputePipelineState,")?;
3114 writeln!(code, " rms_norm_batch_pipeline: ComputePipelineState,")?;
3115 writeln!(code, " rope_batch_pipeline: ComputePipelineState,")?;
3116 writeln!(
3117 code,
3118 " silu_mul_fused_batch_pipeline: ComputePipelineState,"
3119 )?;
3120 writeln!(
3121 code,
3122 " add_inplace_batch_pipeline: ComputePipelineState,"
3123 )?;
3124 writeln!(
3125 code,
3126 " copy_embedding_batch_pipeline: ComputePipelineState,"
3127 )?;
3128 writeln!(code, " attention_batch_pipeline: ComputePipelineState,")?;
3129 writeln!(
3130 code,
3131 " attention_flash_batch_pipeline: ComputePipelineState,"
3132 )?;
3133 writeln!(
3134 code,
3135 " attention_mma_flash_batch_pipeline: ComputePipelineState,"
3136 )?;
3137 writeln!(code, " copy_kv_batch_pipeline: ComputePipelineState,")?;
3138 writeln!(code, " rope_qk_batch_pipeline: ComputePipelineState,")?;
3139 writeln!(
3140 code,
3141 " copy_kv_both_batch_pipeline: ComputePipelineState,"
3142 )?;
3143 writeln!(code)?;
3144 writeln!(code, " // ── Weight buffers (Metal shared memory) ──")?;
3145 writeln!(
3146 code,
3147 " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3148 )?;
3149 writeln!(code, " embed_buf: Buffer,")?;
3150 writeln!(code)?;
3151 writeln!(code, " /// Per-layer weight buffers")?;
3152 writeln!(code, " layers: Vec<LayerBuffers>,")?;
3153 writeln!(code)?;
3154 writeln!(code, " /// Final layer-norm weight [HIDDEN_SIZE]")?;
3155 writeln!(code, " norm_buf: Buffer,")?;
3156 writeln!(code)?;
3157 writeln!(
3158 code,
3159 " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3160 )?;
3161 writeln!(code, " lm_head_buf: Buffer,")?;
3162 writeln!(code)?;
3163 writeln!(
3164 code,
3165 " // ── Working buffers (pre-allocated, reused every forward pass) ──"
3166 )?;
3167 writeln!(code, " hidden_buf: Buffer,")?;
3168 writeln!(code, " residual_buf: Buffer,")?;
3169 writeln!(code, " normed_buf: Buffer,")?;
3170 writeln!(
3171 code,
3172 " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3173 )?;
3174 writeln!(code, " qkv_buf: Buffer,")?;
3175 writeln!(code, " attn_out_buf: Buffer,")?;
3176 writeln!(code, " attn_proj_buf: Buffer,")?;
3177 writeln!(
3178 code,
3179 " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3180 )?;
3181 writeln!(code, " gate_up_buf: Buffer,")?;
3182 writeln!(code, " ffn_hidden_buf: Buffer,")?;
3183 writeln!(code, " ffn_out_buf: Buffer,")?;
3184 writeln!(code, " add_tmp_buf: Buffer,")?;
3185 writeln!(code, " logits_buf: Buffer,")?;
3186 writeln!(code)?;
3187 writeln!(code, " // ── Batched prefill working buffers ──")?;
3188 writeln!(code, " /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3189 writeln!(code, " batch_hidden_buf: Buffer,")?;
3190 writeln!(
3191 code,
3192 " /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3193 )?;
3194 writeln!(code, " batch_residual_buf: Buffer,")?;
3195 writeln!(code, " /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3196 writeln!(code, " batch_qkv_buf: Buffer,")?;
3197 writeln!(
3198 code,
3199 " /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3200 )?;
3201 writeln!(code, " batch_attn_out_buf: Buffer,")?;
3202 writeln!(
3203 code,
3204 " /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3205 )?;
3206 writeln!(code, " batch_attn_proj_buf: Buffer,")?;
3207 writeln!(
3208 code,
3209 " /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3210 )?;
3211 writeln!(code, " batch_gate_up_buf: Buffer,")?;
3212 writeln!(
3213 code,
3214 " /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3215 )?;
3216 writeln!(code, " batch_ffn_hidden_buf: Buffer,")?;
3217 writeln!(code, " /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3218 writeln!(code, " batch_ffn_out_buf: Buffer,")?;
3219 writeln!(code, " /// Token IDs buffer for batch embedding lookup")?;
3220 writeln!(code, " batch_tokens_buf: Buffer,")?;
3221 writeln!(code, " /// Positions buffer for batch RoPE")?;
3222 writeln!(code, " batch_positions_buf: Buffer,")?;
3223 writeln!(code)?;
3224 writeln!(code, " // ── KV cache buffers (per-layer) ──")?;
3225 writeln!(code, " k_cache: Vec<Buffer>, // per-layer")?;
3226 writeln!(code, " v_cache: Vec<Buffer>, // per-layer")?;
3227 writeln!(code)?;
3228 writeln!(code, " // ── Inference state ──")?;
3229 writeln!(code, " pos: usize,")?;
3230 writeln!(code)?;
3231 writeln!(
3232 code,
3233 " /// Previous command buffer for double-buffered prefill."
3234 )?;
3235 writeln!(
3236 code,
3237 " /// While the GPU executes token N, the CPU can encode token N+1."
3238 )?;
3239 writeln!(code, " prev_cmd: Option<CommandBuffer>,")?;
3240 writeln!(code, "}}")?;
3241 writeln!(code)?;
3242
3243 Ok(())
3244}
3245
3246fn emit_layer_buffers_struct(
3247 code: &mut String,
3248 config: &ModelConfig,
3249) -> Result<(), MetalCodegenError> {
3250 writeln!(
3251 code,
3252 "/// Per-layer weight buffers for attention and FFN projections."
3253 )?;
3254 writeln!(code, "struct LayerBuffers {{")?;
3255 writeln!(code, " attn_norm: Buffer,")?;
3256 writeln!(
3257 code,
3258 " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3259 )?;
3260 writeln!(code, " qkv_weight: Buffer,")?;
3261 if config.qkv_bias {
3262 writeln!(
3263 code,
3264 " /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3265 )?;
3266 writeln!(code, " qkv_bias: Buffer,")?;
3267 }
3268 writeln!(code, " o_weight: Buffer,")?;
3269 writeln!(code, " ffn_norm: Buffer,")?;
3270 writeln!(
3271 code,
3272 " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3273 )?;
3274 writeln!(code, " gate_up_weight: Buffer,")?;
3275 writeln!(code, " down_weight: Buffer,")?;
3276 writeln!(code, "}}")?;
3277 writeln!(code)?;
3278
3279 Ok(())
3280}
3281
3282fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3283 let hidden = config.hidden_size;
3284 let intermediate = config.intermediate_size;
3285 let _num_layers = config.num_layers;
3286 let _num_heads = config.num_attention_heads;
3287 let num_kv_heads = config.num_kv_heads;
3288 let head_dim = config.head_dim;
3289 let vocab = config.vocab_size;
3290 let effective_seq_len = config.max_seq_len.min(4096);
3291 let is_q8 = config.dtype == DType::Q8_0;
3292 let is_q4 = config.dtype == DType::Q4_0;
3293 let kv_dim = num_kv_heads * head_dim;
3294
3295 writeln!(code, "impl MetalModel {{")?;
3296
3297 writeln!(
3299 code,
3300 " /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3301 )?;
3302 writeln!(code, " ///")?;
3303 writeln!(
3304 code,
3305 " /// `weights` is the raw weight blob produced by `forge export-weights`."
3306 )?;
3307 writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
3308 writeln!(
3309 code,
3310 " let device = Device::system_default().expect(\"no Metal device found\");"
3311 )?;
3312 writeln!(code, " let queue = device.new_command_queue();")?;
3313 writeln!(code)?;
3314
3315 writeln!(
3317 code,
3318 " // Compile Metal shaders from embedded source"
3319 )?;
3320 writeln!(
3321 code,
3322 " let shader_source = include_str!(\"../shaders/kernels.metal\");"
3323 )?;
3324 writeln!(
3325 code,
3326 " let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3327 )?;
3328 writeln!(
3329 code,
3330 " .expect(\"failed to compile Metal shaders\");"
3331 )?;
3332 writeln!(code)?;
3333
3334 writeln!(code, " // Create compute pipelines")?;
3336 for (var, fn_name) in [
3337 ("matmul_pipeline", "matmul_vec"),
3338 ("matmul_q8_pipeline", "matmul_vec_q8"),
3339 ("matmul_q4_pipeline", "matmul_vec_q4"),
3340 ("rms_norm_pipeline", "rms_norm"),
3341 ("rope_pipeline", "rope"),
3342 ("softmax_pipeline", "softmax"),
3343 ("silu_mul_pipeline", "silu_mul"),
3344 ("silu_mul_fused_pipeline", "silu_mul_fused"),
3345 ("add_pipeline", "elementwise_add"),
3346 ("attention_pipeline", "attention"),
3347 ("add_inplace_pipeline", "add_inplace"),
3348 ("copy_pipeline", "copy_buffer"),
3349 ("copy_offset_pipeline", "copy_offset"),
3350 ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3351 ("matmul_batch_pipeline", "matmul_vec_batch"),
3352 ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3353 ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3354 ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3355 ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3356 ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3357 ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3358 ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3359 ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3360 ("rms_norm_batch_pipeline", "rms_norm_batch"),
3361 ("rope_batch_pipeline", "rope_batch"),
3362 ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3363 ("add_inplace_batch_pipeline", "add_inplace_batch"),
3364 ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3365 ("attention_batch_pipeline", "attention_batch"),
3366 ("attention_flash_batch_pipeline", "attention_flash_batch"),
3367 (
3368 "attention_mma_flash_batch_pipeline",
3369 "attention_mma_flash_batch",
3370 ),
3371 ("copy_kv_batch_pipeline", "copy_kv_batch"),
3372 ("rope_qk_batch_pipeline", "rope_qk_batch"),
3373 ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3374 ] {
3375 writeln!(
3376 code,
3377 " let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3378 )?;
3379 }
3380 if config.qkv_bias {
3381 writeln!(
3382 code,
3383 " let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3384 )?;
3385 }
3386 writeln!(code)?;
3387
3388 writeln!(
3390 code,
3391 " // Load weights into Metal shared-memory buffers"
3392 )?;
3393 writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
3394 writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3395 writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
3396 writeln!(code)?;
3397 writeln!(
3398 code,
3399 " let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
3400 )?;
3401 writeln!(code)?;
3402 writeln!(
3403 code,
3404 " // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3405 )?;
3406 writeln!(
3407 code,
3408 " let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3409 )?;
3410 writeln!(code, " let byte_len = n * f32_size;")?;
3411 writeln!(code, " let cur = cursor.get();")?;
3412 writeln!(
3413 code,
3414 " let data = &weights[cur..cur + byte_len];"
3415 )?;
3416 writeln!(code, " cursor.set(cur + byte_len);")?;
3417 writeln!(code, " device.new_buffer_with_data(")?;
3418 writeln!(code, " data.as_ptr() as *const _,")?;
3419 writeln!(code, " byte_len as u64,")?;
3420 writeln!(
3421 code,
3422 " MTLResourceOptions::StorageModeShared,"
3423 )?;
3424 writeln!(code, " )")?;
3425 writeln!(code, " }};")?;
3426 writeln!(code)?;
3427
3428 if is_q8 {
3429 writeln!(
3434 code,
3435 " // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3436 )?;
3437 writeln!(
3438 code,
3439 " // as raw bytes into a Metal buffer (no dequantization)."
3440 )?;
3441 writeln!(
3442 code,
3443 " // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3444 )?;
3445 writeln!(
3446 code,
3447 " let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3448 )?;
3449 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3450 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3451 writeln!(code, " let total_raw = rows * row_bytes;")?;
3452 writeln!(code, " let cur = cursor.get();")?;
3453 writeln!(
3454 code,
3455 " let data = &weights[cur..cur + total_raw];"
3456 )?;
3457 writeln!(code, " cursor.set(cur + total_raw);")?;
3458 writeln!(code, " device.new_buffer_with_data(")?;
3459 writeln!(code, " data.as_ptr() as *const _,")?;
3460 writeln!(code, " total_raw as u64,")?;
3461 writeln!(
3462 code,
3463 " MTLResourceOptions::StorageModeShared,"
3464 )?;
3465 writeln!(code, " )")?;
3466 writeln!(code, " }};")?;
3467 writeln!(code)?;
3468 writeln!(
3469 code,
3470 " // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3471 )?;
3472 writeln!(
3473 code,
3474 " // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3475 )?;
3476 writeln!(
3477 code,
3478 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3479 )?;
3480 writeln!(
3481 code,
3482 " let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3483 )?;
3484 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3485 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3486 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3487 writeln!(code, " let cur = cursor.get();")?;
3488 writeln!(
3489 code,
3490 " let data = &weights[cur..cur + total_raw];"
3491 )?;
3492 writeln!(code, " cursor.set(cur + total_raw);")?;
3493 writeln!(code, " device.new_buffer_with_data(")?;
3494 writeln!(code, " data.as_ptr() as *const _,")?;
3495 writeln!(code, " total_raw as u64,")?;
3496 writeln!(
3497 code,
3498 " MTLResourceOptions::StorageModeShared,"
3499 )?;
3500 writeln!(code, " )")?;
3501 writeln!(code, " }};")?;
3502 writeln!(code)?;
3503 }
3504
3505 if is_q4 {
3506 writeln!(
3511 code,
3512 " // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3513 )?;
3514 writeln!(
3515 code,
3516 " // as raw bytes into a Metal buffer (no dequantization)."
3517 )?;
3518 writeln!(
3519 code,
3520 " // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3521 )?;
3522 writeln!(
3523 code,
3524 " let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3525 )?;
3526 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3527 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3528 writeln!(code, " let total_raw = rows * row_bytes;")?;
3529 writeln!(code, " let cur = cursor.get();")?;
3530 writeln!(
3531 code,
3532 " let data = &weights[cur..cur + total_raw];"
3533 )?;
3534 writeln!(code, " cursor.set(cur + total_raw);")?;
3535 writeln!(code, " device.new_buffer_with_data(")?;
3536 writeln!(code, " data.as_ptr() as *const _,")?;
3537 writeln!(code, " total_raw as u64,")?;
3538 writeln!(
3539 code,
3540 " MTLResourceOptions::StorageModeShared,"
3541 )?;
3542 writeln!(code, " )")?;
3543 writeln!(code, " }};")?;
3544 writeln!(code)?;
3545 writeln!(
3546 code,
3547 " // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3548 )?;
3549 writeln!(
3550 code,
3551 " // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3552 )?;
3553 writeln!(
3554 code,
3555 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3556 )?;
3557 writeln!(
3558 code,
3559 " let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3560 )?;
3561 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3562 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3563 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3564 writeln!(code, " let cur = cursor.get();")?;
3565 writeln!(
3566 code,
3567 " let data = &weights[cur..cur + total_raw];"
3568 )?;
3569 writeln!(code, " cursor.set(cur + total_raw);")?;
3570 writeln!(code, " device.new_buffer_with_data(")?;
3571 writeln!(code, " data.as_ptr() as *const _,")?;
3572 writeln!(code, " total_raw as u64,")?;
3573 writeln!(
3574 code,
3575 " MTLResourceOptions::StorageModeShared,"
3576 )?;
3577 writeln!(code, " )")?;
3578 writeln!(code, " }};")?;
3579 writeln!(code)?;
3580 }
3581
3582 writeln!(
3583 code,
3584 " let embed_buf = next_f32_buffer(&device, embed_elems);"
3585 )?;
3586 writeln!(code)?;
3587
3588 writeln!(
3590 code,
3591 " let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3592 )?;
3593 writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
3594
3595 writeln!(
3597 code,
3598 " let attn_norm = next_f32_buffer(&device, hidden_elems);"
3599 )?;
3600
3601 let qkv_rows = hidden + 2 * kv_dim;
3602 if is_q8 {
3603 writeln!(
3605 code,
3606 " let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3607 )?;
3608 if config.qkv_bias {
3609 writeln!(
3610 code,
3611 " // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3612 )?;
3613 writeln!(
3614 code,
3615 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3616 )?;
3617 }
3618 writeln!(
3619 code,
3620 " let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3621 )?;
3622 } else if is_q4 {
3623 writeln!(
3624 code,
3625 " let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3626 )?;
3627 if config.qkv_bias {
3628 writeln!(
3629 code,
3630 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3631 )?;
3632 }
3633 writeln!(
3634 code,
3635 " let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3636 )?;
3637 } else {
3638 writeln!(
3639 code,
3640 " let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3641 )?;
3642 if config.qkv_bias {
3643 writeln!(
3644 code,
3645 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3646 )?;
3647 }
3648 writeln!(
3649 code,
3650 " let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3651 )?;
3652 }
3653
3654 writeln!(
3656 code,
3657 " let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3658 )?;
3659
3660 let gate_up_rows = 2 * intermediate;
3661 if is_q8 {
3662 writeln!(
3664 code,
3665 " let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3666 )?;
3667 writeln!(
3668 code,
3669 " let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3670 )?;
3671 } else if is_q4 {
3672 writeln!(
3674 code,
3675 " let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3676 )?;
3677 writeln!(
3678 code,
3679 " let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3680 )?;
3681 } else {
3682 writeln!(
3684 code,
3685 " let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3686 )?;
3687 writeln!(
3688 code,
3689 " let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3690 )?;
3691 }
3692
3693 writeln!(code, " layers.push(LayerBuffers {{")?;
3694 writeln!(code, " attn_norm,")?;
3695 writeln!(code, " qkv_weight,")?;
3696 if config.qkv_bias {
3697 writeln!(code, " qkv_bias,")?;
3698 }
3699 writeln!(code, " o_weight,")?;
3700 writeln!(code, " ffn_norm,")?;
3701 writeln!(code, " gate_up_weight,")?;
3702 writeln!(code, " down_weight,")?;
3703 writeln!(code, " }});")?;
3704 writeln!(code, " }}")?;
3705 writeln!(code)?;
3706
3707 writeln!(
3709 code,
3710 " let norm_buf = next_f32_buffer(&device, hidden_elems);"
3711 )?;
3712 writeln!(code)?;
3713
3714 if is_q8 {
3716 writeln!(
3717 code,
3718 " let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3719 )?;
3720 } else if is_q4 {
3721 writeln!(
3722 code,
3723 " let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3724 )?;
3725 } else {
3726 writeln!(
3727 code,
3728 " let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3729 )?;
3730 }
3731 writeln!(code)?;
3732
3733 let hidden_bytes = hidden * 4;
3735 let _kv_dim_bytes = kv_dim * 4;
3736 let intermediate_bytes = intermediate * 4;
3737 let vocab_bytes = vocab * 4;
3738 let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3741
3742 writeln!(
3743 code,
3744 " // Allocate working buffers (shared memory for zero-copy)"
3745 )?;
3746 writeln!(
3747 code,
3748 " let opts = MTLResourceOptions::StorageModeShared;"
3749 )?;
3750 writeln!(
3751 code,
3752 " let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3753 )?;
3754 writeln!(
3755 code,
3756 " let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3757 )?;
3758 let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3759 writeln!(
3760 code,
3761 " let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3762 )?;
3763 writeln!(
3764 code,
3765 " // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3766 )?;
3767 writeln!(
3768 code,
3769 " let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3770 )?;
3771 writeln!(
3772 code,
3773 " let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3774 )?;
3775 writeln!(
3776 code,
3777 " let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3778 )?;
3779 let gate_up_buf_bytes = 2 * intermediate * 4;
3780 writeln!(
3781 code,
3782 " // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3783 )?;
3784 writeln!(
3785 code,
3786 " let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3787 )?;
3788 writeln!(
3789 code,
3790 " let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3791 )?;
3792 writeln!(
3793 code,
3794 " let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3795 )?;
3796 writeln!(
3797 code,
3798 " let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3799 )?;
3800 writeln!(
3801 code,
3802 " let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3803 )?;
3804 writeln!(code)?;
3805
3806 let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3809 let batch_gate_up_bytes = 2 * intermediate * 4;
3810 let batch_intermediate_bytes = intermediate * 4;
3811 writeln!(
3812 code,
3813 " // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3814 )?;
3815 writeln!(
3816 code,
3817 " let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3818 )?;
3819 writeln!(
3820 code,
3821 " let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3822 )?;
3823 writeln!(
3824 code,
3825 " let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3826 )?;
3827 writeln!(
3828 code,
3829 " let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3830 )?;
3831 writeln!(
3832 code,
3833 " let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3834 )?;
3835 writeln!(
3836 code,
3837 " let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3838 )?;
3839 writeln!(
3840 code,
3841 " let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3842 )?;
3843 writeln!(
3844 code,
3845 " let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3846 )?;
3847 writeln!(
3848 code,
3849 " let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3850 )?;
3851 writeln!(
3852 code,
3853 " let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3854 )?;
3855 writeln!(code)?;
3856
3857 writeln!(code, " // KV cache buffers (per-layer)")?;
3859 writeln!(
3860 code,
3861 " let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3862 )?;
3863 writeln!(
3864 code,
3865 " let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3866 )?;
3867 writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
3868 writeln!(
3869 code,
3870 " k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3871 )?;
3872 writeln!(
3873 code,
3874 " v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3875 )?;
3876 writeln!(code, " }}")?;
3877 writeln!(code)?;
3878
3879 writeln!(code, " Self {{")?;
3880 writeln!(code, " device,")?;
3881 writeln!(code, " queue,")?;
3882 writeln!(code, " matmul_pipeline,")?;
3883 writeln!(code, " matmul_q8_pipeline,")?;
3884 writeln!(code, " matmul_q4_pipeline,")?;
3885 writeln!(code, " rms_norm_pipeline,")?;
3886 writeln!(code, " rope_pipeline,")?;
3887 writeln!(code, " softmax_pipeline,")?;
3888 writeln!(code, " silu_mul_pipeline,")?;
3889 writeln!(code, " silu_mul_fused_pipeline,")?;
3890 writeln!(code, " add_pipeline,")?;
3891 writeln!(code, " attention_pipeline,")?;
3892 writeln!(code, " add_inplace_pipeline,")?;
3893 writeln!(code, " copy_pipeline,")?;
3894 writeln!(code, " copy_offset_pipeline,")?;
3895 writeln!(code, " copy_f32_to_f16_offset_pipeline,")?;
3896 writeln!(code, " matmul_batch_pipeline,")?;
3897 writeln!(code, " matmul_q8_batch_pipeline,")?;
3898 writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
3899 writeln!(code, " matmul_q8_mma_pipeline,")?;
3900 writeln!(code, " matmul_q8_mma32_pipeline,")?;
3901 writeln!(code, " matmul_q8_mma32_h_pipeline,")?;
3902 writeln!(code, " matmul_q8_mma32_h4_pipeline,")?;
3903 writeln!(code, " matmul_q8_mma32_hh4_pipeline,")?;
3904 if config.qkv_bias {
3905 writeln!(code, " add_bias_batch_pipeline,")?;
3906 }
3907 writeln!(code, " matmul_q4_batch_pipeline,")?;
3908 writeln!(code, " rms_norm_batch_pipeline,")?;
3909 writeln!(code, " rope_batch_pipeline,")?;
3910 writeln!(code, " silu_mul_fused_batch_pipeline,")?;
3911 writeln!(code, " add_inplace_batch_pipeline,")?;
3912 writeln!(code, " copy_embedding_batch_pipeline,")?;
3913 writeln!(code, " attention_batch_pipeline,")?;
3914 writeln!(code, " attention_flash_batch_pipeline,")?;
3915 writeln!(code, " attention_mma_flash_batch_pipeline,")?;
3916 writeln!(code, " copy_kv_batch_pipeline,")?;
3917 writeln!(code, " rope_qk_batch_pipeline,")?;
3918 writeln!(code, " copy_kv_both_batch_pipeline,")?;
3919 writeln!(code, " embed_buf,")?;
3920 writeln!(code, " layers,")?;
3921 writeln!(code, " norm_buf,")?;
3922 writeln!(code, " lm_head_buf,")?;
3923 writeln!(code, " hidden_buf,")?;
3924 writeln!(code, " residual_buf,")?;
3925 writeln!(code, " normed_buf,")?;
3926 writeln!(code, " qkv_buf,")?;
3927 writeln!(code, " attn_out_buf,")?;
3928 writeln!(code, " attn_proj_buf,")?;
3929 writeln!(code, " gate_up_buf,")?;
3930 writeln!(code, " ffn_hidden_buf,")?;
3931 writeln!(code, " ffn_out_buf,")?;
3932 writeln!(code, " add_tmp_buf,")?;
3933 writeln!(code, " logits_buf,")?;
3934 writeln!(code, " batch_hidden_buf,")?;
3935 writeln!(code, " batch_residual_buf,")?;
3936 writeln!(code, " batch_qkv_buf,")?;
3937 writeln!(code, " batch_attn_out_buf,")?;
3938 writeln!(code, " batch_attn_proj_buf,")?;
3939 writeln!(code, " batch_gate_up_buf,")?;
3940 writeln!(code, " batch_ffn_hidden_buf,")?;
3941 writeln!(code, " batch_ffn_out_buf,")?;
3942 writeln!(code, " batch_tokens_buf,")?;
3943 writeln!(code, " batch_positions_buf,")?;
3944 writeln!(code, " k_cache,")?;
3945 writeln!(code, " v_cache,")?;
3946 writeln!(code, " pos: 0,")?;
3947 writeln!(code, " prev_cmd: None,")?;
3948 writeln!(code, " }}")?;
3949 writeln!(code, " }}")?;
3950 writeln!(code)?;
3951
3952 writeln!(
3954 code,
3955 " /// Run the forward pass for a single token at the current position."
3956 )?;
3957 writeln!(code, " ///")?;
3958 writeln!(
3959 code,
3960 " /// Returns logits over the vocabulary as a `Vec<f32>`."
3961 )?;
3962 writeln!(code, " ///")?;
3963 writeln!(
3964 code,
3965 " /// All GPU operations are encoded into a single command buffer and"
3966 )?;
3967 writeln!(
3968 code,
3969 " /// committed once at the end, avoiding per-operation synchronization."
3970 )?;
3971 writeln!(
3972 code,
3973 " pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3974 )?;
3975 writeln!(
3976 code,
3977 " // Wait for any pending prefill command buffer"
3978 )?;
3979 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
3980 writeln!(code, " prev.wait_until_completed();")?;
3981 writeln!(code, " }}")?;
3982 writeln!(code)?;
3983 writeln!(code, " let pos = self.pos;")?;
3984 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3985 writeln!(code)?;
3986
3987 let matmul_fn = if is_q8 {
3990 "dispatch_matmul_q8"
3991 } else if is_q4 {
3992 "dispatch_matmul_q4"
3993 } else {
3994 "dispatch_matmul"
3995 };
3996
3997 writeln!(
3998 code,
3999 " // Single compute encoder for the entire forward pass (no blit transitions)"
4000 )?;
4001 writeln!(code, " {{")?;
4002 writeln!(
4003 code,
4004 " let enc = cmd.new_compute_command_encoder();"
4005 )?;
4006 writeln!(code)?;
4007
4008 writeln!(
4010 code,
4011 " // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
4012 )?;
4013 writeln!(
4014 code,
4015 " // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
4016 )?;
4017 writeln!(
4018 code,
4019 " // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
4020 )?;
4021 writeln!(
4022 code,
4023 " // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4024 hidden * 4,
4025 )?;
4026 writeln!(code, " unsafe {{")?;
4027 writeln!(
4028 code,
4029 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4030 )?;
4031 writeln!(
4032 code,
4033 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4034 )?;
4035 writeln!(
4036 code,
4037 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4038 )?;
4039 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4040 writeln!(
4041 code,
4042 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4043 )?;
4044 writeln!(code, " hidden_ptr,")?;
4045 writeln!(code, " HIDDEN_SIZE,")?;
4046 writeln!(code, " );")?;
4047 writeln!(
4048 code,
4049 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4050 )?;
4051 writeln!(code, " }}")?;
4052 writeln!(code)?;
4053
4054 writeln!(code, " // 2. Transformer layers")?;
4056 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4057 writeln!(code)?;
4058 let q_byte_offset = 0usize;
4059 let k_float_offset = hidden;
4060 let v_float_offset = hidden + kv_dim;
4061
4062 writeln!(
4063 code,
4064 " // Pre-attention: rms_norm, fused QKV projection, RoPE"
4065 )?;
4066 writeln!(
4067 code,
4068 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4069 )?;
4070 writeln!(
4071 code,
4072 " // Fused Q+K+V matmul: single dispatch for all three projections"
4073 )?;
4074 writeln!(
4075 code,
4076 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4077 )?;
4078 if config.qkv_bias {
4079 writeln!(
4080 code,
4081 " // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4082 )?;
4083 writeln!(
4084 code,
4085 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4086 )?;
4087 }
4088 writeln!(
4089 code,
4090 " // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
4091 )?;
4092 writeln!(
4093 code,
4094 " self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4095 )?;
4096 writeln!(code)?;
4097 writeln!(
4098 code,
4099 " // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
4100 )?;
4101 writeln!(
4102 code,
4103 " self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4104 )?;
4105 writeln!(code)?;
4106 writeln!(
4107 code,
4108 " // Attention using Q from qkv_buf (offset 0)"
4109 )?;
4110 writeln!(
4111 code,
4112 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4113 )?;
4114 writeln!(
4115 code,
4116 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4117 )?;
4118 writeln!(
4119 code,
4120 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4121 )?;
4122 writeln!(
4123 code,
4124 " // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4125 )?;
4126 writeln!(
4127 code,
4128 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4129 )?;
4130 writeln!(
4131 code,
4132 " // Fused gate+up matmul: single dispatch for both projections"
4133 )?;
4134 writeln!(
4135 code,
4136 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4137 )?;
4138 writeln!(
4139 code,
4140 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4141 )?;
4142 writeln!(
4143 code,
4144 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4145 )?;
4146 writeln!(
4147 code,
4148 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4149 )?;
4150 writeln!(code, " }}")?;
4151 writeln!(code)?;
4152
4153 writeln!(code, " // 3. Final RMS norm + logits projection")?;
4155 writeln!(
4156 code,
4157 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4158 )?;
4159 writeln!(
4160 code,
4161 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4162 )?;
4163 writeln!(code)?;
4164 writeln!(code, " enc.end_encoding();")?;
4165 writeln!(code, " }}")?;
4166 writeln!(code)?;
4167
4168 writeln!(
4170 code,
4171 " // 5. Commit all GPU work and wait for completion"
4172 )?;
4173 writeln!(code, " cmd.commit();")?;
4174 writeln!(code, " cmd.wait_until_completed();")?;
4175 writeln!(code)?;
4176 writeln!(code, " // 6. Read back logits from GPU")?;
4177 writeln!(code, " let logits = unsafe {{")?;
4178 writeln!(
4179 code,
4180 " let ptr = self.logits_buf.contents() as *const f32;"
4181 )?;
4182 writeln!(
4183 code,
4184 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4185 )?;
4186 writeln!(code, " }};")?;
4187 writeln!(code)?;
4188 writeln!(code, " self.pos += 1;")?;
4189 writeln!(code, " logits")?;
4190 writeln!(code, " }}")?;
4191 writeln!(code)?;
4192
4193 writeln!(
4195 code,
4196 " /// Profiling forward pass that prints per-stage GPU timing."
4197 )?;
4198 writeln!(code, " ///")?;
4199 writeln!(
4200 code,
4201 " /// Each stage is committed and waited on separately so that GPU timestamps"
4202 )?;
4203 writeln!(
4204 code,
4205 " /// accurately reflect per-operation cost. This is slower than `forward()` due"
4206 )?;
4207 writeln!(
4208 code,
4209 " /// to the per-stage synchronization, but useful for identifying bottlenecks."
4210 )?;
4211 writeln!(
4212 code,
4213 " pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4214 )?;
4215 writeln!(code, " use std::time::Instant;")?;
4216 writeln!(code)?;
4217 writeln!(
4218 code,
4219 " // Wait for any pending prefill command buffer"
4220 )?;
4221 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4222 writeln!(code, " prev.wait_until_completed();")?;
4223 writeln!(code, " }}")?;
4224 writeln!(code)?;
4225 writeln!(code, " let pos = self.pos;")?;
4226 writeln!(code)?;
4227
4228 writeln!(
4230 code,
4231 " // ── Stage: Embedding lookup (CPU via unified memory) ──"
4232 )?;
4233 writeln!(code, " let t_embed = Instant::now();")?;
4234 writeln!(code, " unsafe {{")?;
4235 writeln!(
4236 code,
4237 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4238 )?;
4239 writeln!(
4240 code,
4241 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4242 )?;
4243 writeln!(
4244 code,
4245 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4246 )?;
4247 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4248 writeln!(
4249 code,
4250 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4251 )?;
4252 writeln!(code, " hidden_ptr,")?;
4253 writeln!(code, " HIDDEN_SIZE,")?;
4254 writeln!(code, " );")?;
4255 writeln!(
4256 code,
4257 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4258 )?;
4259 writeln!(code, " }}")?;
4260 writeln!(code, " let d_embed = t_embed.elapsed();")?;
4261 writeln!(code)?;
4262
4263 writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
4265 writeln!(code, " let t_layers = Instant::now();")?;
4266 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4267 writeln!(code, " {{")?;
4268 writeln!(
4269 code,
4270 " let enc = cmd.new_compute_command_encoder();"
4271 )?;
4272 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4273 writeln!(
4274 code,
4275 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4276 )?;
4277 writeln!(
4278 code,
4279 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4280 )?;
4281 if config.qkv_bias {
4282 writeln!(
4283 code,
4284 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4285 )?;
4286 }
4287 writeln!(
4288 code,
4289 " self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4290 )?;
4291 writeln!(
4292 code,
4293 " self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4294 )?;
4295 writeln!(
4296 code,
4297 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4298 )?;
4299 writeln!(
4300 code,
4301 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4302 )?;
4303 writeln!(
4304 code,
4305 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4306 )?;
4307 writeln!(
4308 code,
4309 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4310 )?;
4311 writeln!(
4312 code,
4313 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4314 )?;
4315 writeln!(
4316 code,
4317 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4318 )?;
4319 writeln!(
4320 code,
4321 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4322 )?;
4323 writeln!(
4324 code,
4325 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4326 )?;
4327 writeln!(code, " }}")?;
4328 writeln!(code, " enc.end_encoding();")?;
4329 writeln!(code, " }}")?;
4330 writeln!(code, " cmd.commit();")?;
4331 writeln!(code, " cmd.wait_until_completed();")?;
4332 writeln!(code, " let d_layers = t_layers.elapsed();")?;
4333 writeln!(code)?;
4334
4335 writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
4337 writeln!(code, " let t_logits = Instant::now();")?;
4338 writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
4339 writeln!(code, " {{")?;
4340 writeln!(
4341 code,
4342 " let enc = cmd2.new_compute_command_encoder();"
4343 )?;
4344 writeln!(
4345 code,
4346 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4347 )?;
4348 writeln!(
4349 code,
4350 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4351 )?;
4352 writeln!(code, " enc.end_encoding();")?;
4353 writeln!(code, " }}")?;
4354 writeln!(code, " cmd2.commit();")?;
4355 writeln!(code, " cmd2.wait_until_completed();")?;
4356 writeln!(code, " let d_logits = t_logits.elapsed();")?;
4357 writeln!(code)?;
4358
4359 writeln!(
4361 code,
4362 " eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4363 )?;
4364 writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
4365 writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
4366 writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
4367 writeln!(
4368 code,
4369 " (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4370 )?;
4371 writeln!(code)?;
4372
4373 writeln!(code, " let logits = unsafe {{")?;
4375 writeln!(
4376 code,
4377 " let ptr = self.logits_buf.contents() as *const f32;"
4378 )?;
4379 writeln!(
4380 code,
4381 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4382 )?;
4383 writeln!(code, " }};")?;
4384 writeln!(code)?;
4385 writeln!(code, " self.pos += 1;")?;
4386 writeln!(code, " logits")?;
4387 writeln!(code, " }}")?;
4388 writeln!(code)?;
4389
4390 writeln!(
4392 code,
4393 " /// Asynchronous forward pass for a single prefill token (no logits readback)."
4394 )?;
4395 writeln!(code, " ///")?;
4396 writeln!(
4397 code,
4398 " /// Commits the command buffer without waiting, enabling double-buffered"
4399 )?;
4400 writeln!(
4401 code,
4402 " /// execution: GPU processes token N while CPU encodes token N+1."
4403 )?;
4404 writeln!(
4405 code,
4406 " pub fn forward_prefill(&mut self, token_id: u32) {{"
4407 )?;
4408 writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
4409 writeln!(code, " }}")?;
4410 writeln!(code)?;
4411
4412 let batch_matmul_fn = if is_q8 {
4415 "dispatch_matmul_q8_batch"
4416 } else if is_q4 {
4417 "dispatch_matmul_q4_batch"
4418 } else {
4419 "dispatch_matmul_batch"
4420 };
4421
4422 writeln!(
4423 code,
4424 " /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4425 )?;
4426 writeln!(code, " ///")?;
4427 writeln!(
4428 code,
4429 " /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4430 )?;
4431 writeln!(
4432 code,
4433 " /// of mat-vec), and batched causal attention with a single GPU dispatch."
4434 )?;
4435 writeln!(
4436 code,
4437 " /// This provides significant speedup during prompt prefill."
4438 )?;
4439 writeln!(
4440 code,
4441 " pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4442 )?;
4443 writeln!(code, " if tokens.is_empty() {{ return; }}")?;
4444 writeln!(
4445 code,
4446 " // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4447 )?;
4448 writeln!(
4449 code,
4450 " // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4451 )?;
4452 writeln!(
4453 code,
4454 " // longer than that must be processed iteratively. The KV cache"
4455 )?;
4456 writeln!(code, " // carries state across chunks via self.pos.")?;
4457 writeln!(
4458 code,
4459 " for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4460 )?;
4461 writeln!(code, " let m = chunk.len();")?;
4462 writeln!(code, " if m == 0 {{ continue; }}")?;
4463 writeln!(code, " let start_pos = self.pos;")?;
4464 writeln!(code)?;
4465 writeln!(code, " // Wait for any pending command buffer")?;
4466 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4467 writeln!(code, " prev.wait_until_completed();")?;
4468 writeln!(code, " }}")?;
4469 writeln!(code)?;
4470
4471 writeln!(
4473 code,
4474 " // Upload token IDs and positions to GPU buffers"
4475 )?;
4476 writeln!(code, " unsafe {{")?;
4477 writeln!(
4478 code,
4479 " let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4480 )?;
4481 writeln!(
4482 code,
4483 " let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4484 )?;
4485 writeln!(code, " for i in 0..m {{")?;
4486 writeln!(code, " *tok_ptr.add(i) = chunk[i];")?;
4487 writeln!(
4488 code,
4489 " *pos_ptr.add(i) = (start_pos + i) as u32;"
4490 )?;
4491 writeln!(code, " }}")?;
4492 writeln!(code, " }}")?;
4493 writeln!(code)?;
4494
4495 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4496 writeln!(code, " {{")?;
4497 writeln!(
4498 code,
4499 " let enc = cmd.new_compute_command_encoder();"
4500 )?;
4501 writeln!(code)?;
4502
4503 writeln!(
4505 code,
4506 " // 1. Batch embedding lookup: copy all token embeddings at once"
4507 )?;
4508 writeln!(
4509 code,
4510 " self.dispatch_copy_embedding_batch(&enc, m);"
4511 )?;
4512 writeln!(
4514 code,
4515 " self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4516 )?;
4517 writeln!(code)?;
4518
4519 writeln!(code, " // 2. Transformer layers")?;
4521 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4522 writeln!(code)?;
4523
4524 writeln!(
4526 code,
4527 " // Batch RMS norm: batch_residual -> batch_hidden"
4528 )?;
4529 writeln!(
4530 code,
4531 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4532 )?;
4533
4534 writeln!(
4536 code,
4537 " // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4538 )?;
4539 writeln!(
4540 code,
4541 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4542 )?;
4543 if config.qkv_bias {
4544 writeln!(
4545 code,
4546 " // Qwen2: broadcast-add QKV bias across all M tokens."
4547 )?;
4548 writeln!(
4549 code,
4550 " self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4551 )?;
4552 }
4553 writeln!(code)?;
4554
4555 let k_float_offset = hidden;
4557 writeln!(
4558 code,
4559 " // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4560 )?;
4561 writeln!(
4562 code,
4563 " self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4564 )?;
4565 writeln!(code)?;
4566
4567 let v_float_offset = hidden + kv_dim;
4569 writeln!(
4570 code,
4571 " // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4572 )?;
4573 writeln!(
4574 code,
4575 " self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4576 )?;
4577 writeln!(code)?;
4578
4579 writeln!(
4581 code,
4582 " // Batched causal attention: one dispatch for all M tokens"
4583 )?;
4584 writeln!(
4585 code,
4586 " self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4587 )?;
4588 writeln!(code)?;
4589
4590 writeln!(code, " // Batched O projection")?;
4592 writeln!(
4593 code,
4594 " self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4595 )?;
4596 writeln!(code)?;
4597
4598 writeln!(
4600 code,
4601 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4602 )?;
4603 writeln!(code)?;
4604
4605 writeln!(
4607 code,
4608 " // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4609 )?;
4610 writeln!(
4611 code,
4612 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4613 )?;
4614 writeln!(
4615 code,
4616 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4617 )?;
4618 writeln!(
4619 code,
4620 " self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4621 )?;
4622 writeln!(
4623 code,
4624 " self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4625 )?;
4626 writeln!(
4627 code,
4628 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4629 )?;
4630 writeln!(code, " }}")?;
4631 writeln!(code)?;
4632
4633 writeln!(
4635 code,
4636 " // Copy last token's residual to single-token buffer for subsequent forward()"
4637 )?;
4638 writeln!(
4639 code,
4640 " self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4641 )?;
4642 writeln!(code)?;
4643 writeln!(code, " enc.end_encoding();")?;
4644 writeln!(code, " }}")?;
4645 writeln!(code)?;
4646
4647 writeln!(code, " cmd.commit();")?;
4648 writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
4649 writeln!(code, " self.pos += m;")?;
4650 writeln!(code, " }} // end for chunk")?;
4651 writeln!(code, " }}")?;
4652 writeln!(code)?;
4653
4654 writeln!(
4656 code,
4657 " /// Reset the model state for a new inference request."
4658 )?;
4659 writeln!(code, " pub fn reset(&mut self) {{")?;
4660 writeln!(code, " self.pos = 0;")?;
4661 writeln!(code, " self.prev_cmd = None;")?;
4662 writeln!(code, " }}")?;
4663 writeln!(code)?;
4664
4665 writeln!(
4667 code,
4668 " // ── Dispatch helpers (append to a shared compute command encoder) ──"
4669 )?;
4670 writeln!(
4671 code,
4672 " // These methods set pipeline state + buffers + dispatch on an existing"
4673 )?;
4674 writeln!(
4675 code,
4676 " // encoder, avoiding per-operation encoder creation overhead."
4677 )?;
4678 writeln!(code)?;
4679
4680 writeln!(
4682 code,
4683 " /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4684 )?;
4685 writeln!(
4686 code,
4687 " fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4688 )?;
4689 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
4690 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
4691 writeln!(
4692 code,
4693 " enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4694 )?;
4695 writeln!(
4696 code,
4697 " enc.set_buffer(0, Some(&self.residual_buf), 0);"
4698 )?;
4699 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
4700 writeln!(
4701 code,
4702 " enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4703 )?;
4704 writeln!(
4705 code,
4706 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4707 )?;
4708 writeln!(
4709 code,
4710 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4711 )?;
4712 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4713 writeln!(
4714 code,
4715 " let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
4716 )?;
4717 writeln!(
4718 code,
4719 " enc.dispatch_thread_groups(grid_size, tg_size);"
4720 )?;
4721 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4722 writeln!(code, " }}")?;
4723 writeln!(code)?;
4724
4725 writeln!(
4727 code,
4728 " /// Dispatch matrix-vector multiply: weight * input -> output."
4729 )?;
4730 writeln!(
4731 code,
4732 " fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4733 )?;
4734 writeln!(code, " let r: u32 = rows as u32;")?;
4735 writeln!(code, " let c: u32 = cols as u32;")?;
4736 writeln!(
4737 code,
4738 " enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4739 )?;
4740 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4741 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4742 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4743 writeln!(
4744 code,
4745 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4746 )?;
4747 writeln!(
4748 code,
4749 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4750 )?;
4751 writeln!(
4752 code,
4753 " // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4754 )?;
4755 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4756 writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
4757 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4758 writeln!(
4759 code,
4760 " enc.dispatch_thread_groups(grid_size, tg_size);"
4761 )?;
4762 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4763 writeln!(code, " }}")?;
4764 writeln!(code)?;
4765
4766 writeln!(
4768 code,
4769 " /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4770 )?;
4771 writeln!(
4772 code,
4773 " /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4774 )?;
4775 writeln!(
4776 code,
4777 " fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4778 )?;
4779 writeln!(code, " let r: u32 = rows as u32;")?;
4780 writeln!(code, " let c: u32 = cols as u32;")?;
4781 writeln!(
4782 code,
4783 " enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4784 )?;
4785 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4786 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4787 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4788 writeln!(
4789 code,
4790 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4791 )?;
4792 writeln!(
4793 code,
4794 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4795 )?;
4796 writeln!(
4797 code,
4798 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4799 )?;
4800 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4801 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4802 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4803 writeln!(
4804 code,
4805 " enc.dispatch_thread_groups(grid_size, tg_size);"
4806 )?;
4807 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4808 writeln!(code, " }}")?;
4809 writeln!(code)?;
4810
4811 writeln!(
4813 code,
4814 " /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4815 )?;
4816 writeln!(
4817 code,
4818 " /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4819 )?;
4820 writeln!(
4821 code,
4822 " fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4823 )?;
4824 writeln!(code, " let r: u32 = rows as u32;")?;
4825 writeln!(code, " let c: u32 = cols as u32;")?;
4826 writeln!(
4827 code,
4828 " enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4829 )?;
4830 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4831 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4832 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4833 writeln!(
4834 code,
4835 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4836 )?;
4837 writeln!(
4838 code,
4839 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4840 )?;
4841 writeln!(
4842 code,
4843 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4844 )?;
4845 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4846 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4847 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4848 writeln!(
4849 code,
4850 " enc.dispatch_thread_groups(grid_size, tg_size);"
4851 )?;
4852 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4853 writeln!(code, " }}")?;
4854 writeln!(code)?;
4855
4856 writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
4858 writeln!(
4859 code,
4860 " fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4861 )?;
4862 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4863 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4864 writeln!(code, " let p: u32 = pos as u32;")?;
4865 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4866 writeln!(
4867 code,
4868 " let total_pairs = num_heads * (head_dim / 2);"
4869 )?;
4870 writeln!(
4871 code,
4872 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4873 )?;
4874 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
4875 writeln!(
4876 code,
4877 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4878 )?;
4879 writeln!(
4880 code,
4881 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4882 )?;
4883 writeln!(
4884 code,
4885 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4886 )?;
4887 writeln!(
4888 code,
4889 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4890 )?;
4891 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4892 writeln!(
4893 code,
4894 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4895 )?;
4896 writeln!(
4897 code,
4898 " enc.dispatch_thread_groups(grid_size, tg_size);"
4899 )?;
4900 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4901 writeln!(code, " }}")?;
4902 writeln!(code)?;
4903
4904 writeln!(
4906 code,
4907 " /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4908 )?;
4909 writeln!(
4910 code,
4911 " fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4912 )?;
4913 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4914 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4915 writeln!(code, " let p: u32 = pos as u32;")?;
4916 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4917 writeln!(
4918 code,
4919 " let total_pairs = num_heads * (head_dim / 2);"
4920 )?;
4921 writeln!(
4922 code,
4923 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4924 )?;
4925 writeln!(
4926 code,
4927 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
4928 )?;
4929 writeln!(
4930 code,
4931 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4932 )?;
4933 writeln!(
4934 code,
4935 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4936 )?;
4937 writeln!(
4938 code,
4939 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4940 )?;
4941 writeln!(
4942 code,
4943 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4944 )?;
4945 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4946 writeln!(
4947 code,
4948 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4949 )?;
4950 writeln!(
4951 code,
4952 " enc.dispatch_thread_groups(grid_size, tg_size);"
4953 )?;
4954 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4955 writeln!(code, " }}")?;
4956 writeln!(code)?;
4957
4958 writeln!(
4960 code,
4961 " /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4962 )?;
4963 writeln!(
4964 code,
4965 " fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4966 )?;
4967 writeln!(code, " let sl: u32 = seq_len as u32;")?;
4968 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4969 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4970 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4971 writeln!(
4972 code,
4973 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
4974 )?;
4975 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
4976 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
4977 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
4978 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
4979 writeln!(
4980 code,
4981 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4982 )?;
4983 writeln!(
4984 code,
4985 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4986 )?;
4987 writeln!(
4988 code,
4989 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4990 )?;
4991 writeln!(
4992 code,
4993 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4994 )?;
4995 writeln!(
4996 code,
4997 " // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4998 )?;
4999 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5000 writeln!(
5001 code,
5002 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5003 )?;
5004 writeln!(
5005 code,
5006 " enc.dispatch_thread_groups(grid_size, tg_size);"
5007 )?;
5008 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5009 writeln!(code, " }}")?;
5010 writeln!(code)?;
5011
5012 writeln!(
5014 code,
5015 " /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
5016 )?;
5017 writeln!(
5018 code,
5019 " fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5020 )?;
5021 writeln!(code, " let sl: u32 = seq_len as u32;")?;
5022 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
5023 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
5024 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
5025 writeln!(
5026 code,
5027 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
5028 )?;
5029 writeln!(
5030 code,
5031 " enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5032 )?;
5033 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
5034 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
5035 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
5036 writeln!(
5037 code,
5038 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5039 )?;
5040 writeln!(
5041 code,
5042 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5043 )?;
5044 writeln!(
5045 code,
5046 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5047 )?;
5048 writeln!(
5049 code,
5050 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5051 )?;
5052 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5053 writeln!(
5054 code,
5055 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5056 )?;
5057 writeln!(
5058 code,
5059 " enc.dispatch_thread_groups(grid_size, tg_size);"
5060 )?;
5061 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5062 writeln!(code, " }}")?;
5063 writeln!(code)?;
5064
5065 writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
5067 writeln!(
5068 code,
5069 " fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5070 )?;
5071 writeln!(code, " let count: u32 = n as u32;")?;
5072 writeln!(
5073 code,
5074 " enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5075 )?;
5076 writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
5077 writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
5078 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5079 writeln!(
5080 code,
5081 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5082 )?;
5083 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5084 writeln!(
5085 code,
5086 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5087 )?;
5088 writeln!(
5089 code,
5090 " enc.dispatch_thread_groups(grid_size, tg_size);"
5091 )?;
5092 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5093 writeln!(code, " }}")?;
5094 writeln!(code)?;
5095
5096 writeln!(
5098 code,
5099 " /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5100 )?;
5101 writeln!(
5102 code,
5103 " /// gate_up_buf contains [gate(n), up(n)] contiguously."
5104 )?;
5105 writeln!(
5106 code,
5107 " fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5108 )?;
5109 writeln!(code, " let count: u32 = n as u32;")?;
5110 writeln!(
5111 code,
5112 " enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5113 )?;
5114 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5115 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5116 writeln!(
5117 code,
5118 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5119 )?;
5120 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5121 writeln!(
5122 code,
5123 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5124 )?;
5125 writeln!(
5126 code,
5127 " enc.dispatch_thread_groups(grid_size, tg_size);"
5128 )?;
5129 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5130 writeln!(code, " }}")?;
5131 writeln!(code)?;
5132
5133 writeln!(
5135 code,
5136 " /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5137 )?;
5138 writeln!(
5139 code,
5140 " fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5141 )?;
5142 writeln!(code, " let n: u32 = count as u32;")?;
5143 writeln!(
5144 code,
5145 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5146 )?;
5147 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5148 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5149 writeln!(
5150 code,
5151 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5152 )?;
5153 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5154 writeln!(
5155 code,
5156 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5157 )?;
5158 writeln!(
5159 code,
5160 " enc.dispatch_thread_groups(grid_size, tg_size);"
5161 )?;
5162 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5163 writeln!(code, " }}")?;
5164 writeln!(code)?;
5165
5166 writeln!(
5168 code,
5169 " /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5170 )?;
5171 writeln!(
5172 code,
5173 " /// Used for embedding table lookup (copy a specific row)."
5174 )?;
5175 writeln!(
5176 code,
5177 " fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5178 )?;
5179 writeln!(code, " let off: u32 = src_offset as u32;")?;
5180 writeln!(code, " let n: u32 = count as u32;")?;
5181 writeln!(
5182 code,
5183 " enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5184 )?;
5185 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5186 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5187 writeln!(
5188 code,
5189 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5190 )?;
5191 writeln!(
5192 code,
5193 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5194 )?;
5195 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5196 writeln!(
5197 code,
5198 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5199 )?;
5200 writeln!(
5201 code,
5202 " enc.dispatch_thread_groups(grid_size, tg_size);"
5203 )?;
5204 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5205 writeln!(code, " }}")?;
5206 writeln!(code)?;
5207
5208 writeln!(
5210 code,
5211 " /// Dispatch copy from source at byte offset to destination at float offset."
5212 )?;
5213 writeln!(
5214 code,
5215 " /// Used for KV cache updates from fused QKV buffer."
5216 )?;
5217 writeln!(
5218 code,
5219 " fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5220 )?;
5221 writeln!(code, " let n: u32 = count as u32;")?;
5222 writeln!(
5223 code,
5224 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5225 )?;
5226 writeln!(
5227 code,
5228 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5229 )?;
5230 writeln!(
5231 code,
5232 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5233 )?;
5234 writeln!(
5235 code,
5236 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5237 )?;
5238 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5239 writeln!(
5240 code,
5241 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5242 )?;
5243 writeln!(
5244 code,
5245 " enc.dispatch_thread_groups(grid_size, tg_size);"
5246 )?;
5247 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5248 writeln!(code, " }}")?;
5249 writeln!(code)?;
5250
5251 writeln!(
5253 code,
5254 " /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5255 )?;
5256 writeln!(
5257 code,
5258 " /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5259 )?;
5260 writeln!(
5261 code,
5262 " fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5263 )?;
5264 writeln!(code, " let n: u32 = count as u32;")?;
5265 writeln!(
5266 code,
5267 " enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5268 )?;
5269 writeln!(
5270 code,
5271 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5272 )?;
5273 writeln!(
5274 code,
5275 " enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5276 )?;
5277 writeln!(
5278 code,
5279 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5280 )?;
5281 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5282 writeln!(
5283 code,
5284 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5285 )?;
5286 writeln!(
5287 code,
5288 " enc.dispatch_thread_groups(grid_size, tg_size);"
5289 )?;
5290 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5291 writeln!(code, " }}")?;
5292 writeln!(code)?;
5293
5294 writeln!(
5296 code,
5297 " /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5298 )?;
5299 writeln!(
5300 code,
5301 " /// Used for KV cache updates (write to a specific position in the cache)."
5302 )?;
5303 writeln!(
5304 code,
5305 " fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5306 )?;
5307 writeln!(code, " let n: u32 = count as u32;")?;
5308 writeln!(
5309 code,
5310 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5311 )?;
5312 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5313 writeln!(
5314 code,
5315 " enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5316 )?;
5317 writeln!(
5318 code,
5319 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5320 )?;
5321 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5322 writeln!(
5323 code,
5324 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5325 )?;
5326 writeln!(
5327 code,
5328 " enc.dispatch_thread_groups(grid_size, tg_size);"
5329 )?;
5330 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5331 writeln!(code, " }}")?;
5332 writeln!(code)?;
5333
5334 writeln!(
5336 code,
5337 " /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5338 )?;
5339 writeln!(
5340 code,
5341 " fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5342 )?;
5343 writeln!(code, " let count: u32 = n as u32;")?;
5344 writeln!(
5345 code,
5346 " enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5347 )?;
5348 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5349 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5350 writeln!(
5351 code,
5352 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5353 )?;
5354 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5355 writeln!(
5356 code,
5357 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5358 )?;
5359 writeln!(
5360 code,
5361 " enc.dispatch_thread_groups(grid_size, tg_size);"
5362 )?;
5363 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5364 writeln!(code, " }}")?;
5365 writeln!(code)?;
5366
5367 writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
5369 writeln!(code)?;
5370
5371 writeln!(
5373 code,
5374 " /// Dispatch batched embedding lookup: copy M token embeddings at once."
5375 )?;
5376 writeln!(
5377 code,
5378 " fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5379 )?;
5380 writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
5381 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5382 writeln!(
5383 code,
5384 " enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5385 )?;
5386 writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5387 writeln!(
5388 code,
5389 " enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5390 )?;
5391 writeln!(
5392 code,
5393 " enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5394 )?;
5395 writeln!(
5396 code,
5397 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5398 )?;
5399 writeln!(
5400 code,
5401 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5402 )?;
5403 writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
5404 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5405 writeln!(
5406 code,
5407 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5408 )?;
5409 writeln!(
5410 code,
5411 " enc.dispatch_thread_groups(grid_size, tg_size);"
5412 )?;
5413 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5414 writeln!(code, " }}")?;
5415 writeln!(code)?;
5416
5417 writeln!(
5419 code,
5420 " /// Dispatch batched RMS norm: normalizes M vectors at once."
5421 )?;
5422 writeln!(
5423 code,
5424 " fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5425 )?;
5426 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
5427 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
5428 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5429 writeln!(
5430 code,
5431 " enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5432 )?;
5433 writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
5434 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
5435 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5436 writeln!(
5437 code,
5438 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5439 )?;
5440 writeln!(
5441 code,
5442 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5443 )?;
5444 writeln!(
5445 code,
5446 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5447 )?;
5448 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5449 writeln!(
5450 code,
5451 " let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5452 )?;
5453 writeln!(
5454 code,
5455 " enc.dispatch_thread_groups(grid_size, tg_size);"
5456 )?;
5457 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5458 writeln!(code, " }}")?;
5459 writeln!(code)?;
5460
5461 writeln!(
5463 code,
5464 " /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5465 )?;
5466 writeln!(
5467 code,
5468 " fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5469 )?;
5470 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5471 writeln!(code, " let r: u32 = rows as u32;")?;
5472 writeln!(code, " let c: u32 = cols as u32;")?;
5473 writeln!(
5474 code,
5475 " enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5476 )?;
5477 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5478 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5479 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5480 writeln!(
5481 code,
5482 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5483 )?;
5484 writeln!(
5485 code,
5486 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5487 )?;
5488 writeln!(
5489 code,
5490 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5491 )?;
5492 writeln!(
5493 code,
5494 " let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
5495 )?;
5496 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5497 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5498 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5499 writeln!(
5500 code,
5501 " enc.dispatch_thread_groups(grid_size, tg_size);"
5502 )?;
5503 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5504 writeln!(code, " }}")?;
5505 writeln!(code)?;
5506
5507 writeln!(
5509 code,
5510 " /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5511 )?;
5512 writeln!(code, " ///")?;
5513 writeln!(
5514 code,
5515 " /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5516 )?;
5517 writeln!(
5518 code,
5519 " /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5520 )?;
5521 writeln!(
5522 code,
5523 " fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5524 )?;
5525 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5526 writeln!(code, " let r: u32 = rows as u32;")?;
5527 writeln!(code, " let c: u32 = cols as u32;")?;
5528 writeln!(
5529 code,
5530 " // Tile sizes must match the Metal shader constants."
5531 )?;
5532 writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
5533 writeln!(code, " const MMA_TOK_TILE: usize = 16;")?;
5534 writeln!(code, " const MMA_ROW_TILE: usize = 16;")?;
5535 writeln!(code, " const MMA32_TOK_TILE: usize = 32;")?;
5536 writeln!(code, " const MMA32_ROW_TILE: usize = 32;")?;
5537 writeln!(
5538 code,
5539 " // Hardware matrix-multiply paths (simdgroup_matrix)."
5540 )?;
5541 writeln!(
5542 code,
5543 " // Prefer the large 32×32 tile when the problem supports it — halves"
5544 )?;
5545 writeln!(
5546 code,
5547 " // dispatch count and reuses each weight load across 32 tokens."
5548 )?;
5549 writeln!(
5550 code,
5551 " if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5552 )?;
5553 writeln!(
5554 code,
5555 " // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5556 )?;
5557 writeln!(
5558 code,
5559 " // It wins at moderate prefill lengths where the GPU is wave-starved,"
5560 )?;
5561 writeln!(
5562 code,
5563 " // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5564 )?;
5565 writeln!(
5566 code,
5567 " // case (135M / 360M). Switch at cols >= 2048 — a clean split that"
5568 )?;
5569 writeln!(
5570 code,
5571 " // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5572 )?;
5573 writeln!(
5574 code,
5575 " // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5576 )?;
5577 writeln!(
5578 code,
5579 " // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5580 )?;
5581 writeln!(
5582 code,
5583 " // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5584 )?;
5585 writeln!(code, " let use_h4 = cols >= 2048;")?;
5586 writeln!(code, " let pipe = if use_h4 {{")?;
5587 writeln!(code, " if num_tokens >= 256 {{")?;
5588 writeln!(
5589 code,
5590 " &self.matmul_q8_mma32_hh4_pipeline"
5591 )?;
5592 writeln!(code, " }} else {{")?;
5593 writeln!(
5594 code,
5595 " &self.matmul_q8_mma32_h4_pipeline"
5596 )?;
5597 writeln!(code, " }}")?;
5598 writeln!(code, " }} else {{")?;
5599 writeln!(code, " &self.matmul_q8_mma32_pipeline")?;
5600 writeln!(code, " }};")?;
5601 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
5602 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5603 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5604 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5605 writeln!(
5606 code,
5607 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5608 )?;
5609 writeln!(
5610 code,
5611 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5612 )?;
5613 writeln!(
5614 code,
5615 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5616 )?;
5617 writeln!(code, " let row_tgs = rows / MMA32_ROW_TILE;")?;
5618 writeln!(
5619 code,
5620 " let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5621 )?;
5622 writeln!(
5623 code,
5624 " let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5625 )?;
5626 writeln!(
5627 code,
5628 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5629 )?;
5630 writeln!(
5631 code,
5632 " enc.dispatch_thread_groups(grid_size, tg_size);"
5633 )?;
5634 writeln!(
5635 code,
5636 " }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5637 )?;
5638 writeln!(
5639 code,
5640 " enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5641 )?;
5642 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5643 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5644 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5645 writeln!(
5646 code,
5647 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5648 )?;
5649 writeln!(
5650 code,
5651 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5652 )?;
5653 writeln!(
5654 code,
5655 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5656 )?;
5657 writeln!(code, " let row_tgs = rows / MMA_ROW_TILE;")?;
5658 writeln!(
5659 code,
5660 " let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5661 )?;
5662 writeln!(
5663 code,
5664 " let tg_size = MTLSize::new(128, 1, 1); // 4 simdgroups × 32 lanes"
5665 )?;
5666 writeln!(
5667 code,
5668 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5669 )?;
5670 writeln!(
5671 code,
5672 " enc.dispatch_thread_groups(grid_size, tg_size);"
5673 )?;
5674 writeln!(code, " }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5675 writeln!(
5676 code,
5677 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5678 )?;
5679 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5680 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5681 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5682 writeln!(
5683 code,
5684 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5685 )?;
5686 writeln!(
5687 code,
5688 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5689 )?;
5690 writeln!(
5691 code,
5692 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5693 )?;
5694 writeln!(
5695 code,
5696 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5697 )?;
5698 writeln!(
5699 code,
5700 " let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5701 )?;
5702 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5703 writeln!(
5704 code,
5705 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5706 )?;
5707 writeln!(
5708 code,
5709 " enc.dispatch_thread_groups(grid_size, tg_size);"
5710 )?;
5711 writeln!(code, " }} else {{")?;
5712 writeln!(
5713 code,
5714 " enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5715 )?;
5716 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5717 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5718 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5719 writeln!(
5720 code,
5721 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5722 )?;
5723 writeln!(
5724 code,
5725 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5726 )?;
5727 writeln!(
5728 code,
5729 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5730 )?;
5731 writeln!(
5732 code,
5733 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5734 )?;
5735 writeln!(
5736 code,
5737 " let num_tg = (row_tgs * num_tokens) as u64;"
5738 )?;
5739 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5740 writeln!(
5741 code,
5742 " let grid_size = MTLSize::new(num_tg, 1, 1);"
5743 )?;
5744 writeln!(
5745 code,
5746 " enc.dispatch_thread_groups(grid_size, tg_size);"
5747 )?;
5748 writeln!(code, " }}")?;
5749 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5750 writeln!(code, " }}")?;
5751 writeln!(code)?;
5752
5753 writeln!(
5755 code,
5756 " /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5757 )?;
5758 writeln!(
5759 code,
5760 " fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5761 )?;
5762 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5763 writeln!(code, " let r: u32 = rows as u32;")?;
5764 writeln!(code, " let c: u32 = cols as u32;")?;
5765 writeln!(
5766 code,
5767 " enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5768 )?;
5769 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5770 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5771 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5772 writeln!(
5773 code,
5774 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5775 )?;
5776 writeln!(
5777 code,
5778 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5779 )?;
5780 writeln!(
5781 code,
5782 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5783 )?;
5784 writeln!(
5785 code,
5786 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
5787 )?;
5788 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5789 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5790 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5791 writeln!(
5792 code,
5793 " enc.dispatch_thread_groups(grid_size, tg_size);"
5794 )?;
5795 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5796 writeln!(code, " }}")?;
5797 writeln!(code)?;
5798
5799 if config.qkv_bias {
5801 writeln!(
5802 code,
5803 " /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5804 )?;
5805 writeln!(
5806 code,
5807 " fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5808 )?;
5809 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5810 writeln!(code, " let r: u32 = rows as u32;")?;
5811 writeln!(
5812 code,
5813 " enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5814 )?;
5815 writeln!(code, " enc.set_buffer(0, Some(out), 0);")?;
5816 writeln!(code, " enc.set_buffer(1, Some(bias), 0);")?;
5817 writeln!(
5818 code,
5819 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5820 )?;
5821 writeln!(
5822 code,
5823 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5824 )?;
5825 writeln!(code, " let total = (num_tokens * rows) as u64;")?;
5826 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5827 writeln!(
5828 code,
5829 " let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5830 )?;
5831 writeln!(
5832 code,
5833 " enc.dispatch_thread_groups(grid_size, tg_size);"
5834 )?;
5835 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5836 writeln!(code, " }}")?;
5837 writeln!(code)?;
5838 }
5839
5840 writeln!(
5842 code,
5843 " /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5844 )?;
5845 writeln!(
5846 code,
5847 " /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5848 )?;
5849 writeln!(
5850 code,
5851 " /// `row_stride` is the number of floats per token row in the batch buffer."
5852 )?;
5853 writeln!(
5854 code,
5855 " fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5856 )?;
5857 writeln!(code, " let nh: u32 = num_heads as u32;")?;
5858 writeln!(code, " let hd: u32 = head_dim as u32;")?;
5859 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5860 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5861 writeln!(
5862 code,
5863 " let pairs_per_token = num_heads * (head_dim / 2);"
5864 )?;
5865 writeln!(
5866 code,
5867 " let total_pairs = num_tokens * pairs_per_token;"
5868 )?;
5869 writeln!(
5882 code,
5883 " // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5884 )?;
5885 writeln!(code, " for t in 0..num_tokens {{")?;
5886 writeln!(
5887 code,
5888 " let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5889 )?;
5890 writeln!(
5891 code,
5892 " let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5893 )?;
5894 writeln!(
5895 code,
5896 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
5897 )?;
5898 writeln!(
5899 code,
5900 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
5901 )?;
5902 writeln!(
5903 code,
5904 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5905 )?;
5906 writeln!(
5907 code,
5908 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5909 )?;
5910 writeln!(
5911 code,
5912 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5913 )?;
5914 writeln!(
5915 code,
5916 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5917 )?;
5918 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5919 writeln!(
5920 code,
5921 " let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5922 )?;
5923 writeln!(
5924 code,
5925 " enc.dispatch_thread_groups(grid_size, tg_size);"
5926 )?;
5927 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5928 writeln!(code, " }}")?;
5929 writeln!(code, " }}")?;
5930 writeln!(code)?;
5931
5932 writeln!(
5934 code,
5935 " /// Dispatch batched fused SiLU-multiply for M tokens."
5936 )?;
5937 writeln!(
5938 code,
5939 " fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5940 )?;
5941 writeln!(code, " let count: u32 = n as u32;")?;
5942 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5943 writeln!(
5944 code,
5945 " enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5946 )?;
5947 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5948 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5949 writeln!(
5950 code,
5951 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5952 )?;
5953 writeln!(
5954 code,
5955 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5956 )?;
5957 writeln!(code, " let total = num_tokens * n;")?;
5958 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5959 writeln!(
5960 code,
5961 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5962 )?;
5963 writeln!(
5964 code,
5965 " enc.dispatch_thread_groups(grid_size, tg_size);"
5966 )?;
5967 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5968 writeln!(code, " }}")?;
5969 writeln!(code)?;
5970
5971 writeln!(
5973 code,
5974 " /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5975 )?;
5976 writeln!(
5977 code,
5978 " fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5979 )?;
5980 writeln!(code, " let count: u32 = total_n as u32;")?;
5981 writeln!(
5982 code,
5983 " enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5984 )?;
5985 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5986 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5987 writeln!(
5988 code,
5989 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5990 )?;
5991 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5992 writeln!(
5993 code,
5994 " let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5995 )?;
5996 writeln!(
5997 code,
5998 " enc.dispatch_thread_groups(grid_size, tg_size);"
5999 )?;
6000 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6001 writeln!(code, " }}")?;
6002 writeln!(code)?;
6003
6004 writeln!(
6006 code,
6007 " /// Copy src to dst using compute copy kernel (for batch residual init)."
6008 )?;
6009 writeln!(
6010 code,
6011 " fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
6012 )?;
6013 writeln!(code, " let n: u32 = count as u32;")?;
6014 writeln!(
6015 code,
6016 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6017 )?;
6018 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6019 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6020 writeln!(
6021 code,
6022 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6023 )?;
6024 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6025 writeln!(
6026 code,
6027 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6028 )?;
6029 writeln!(
6030 code,
6031 " enc.dispatch_thread_groups(grid_size, tg_size);"
6032 )?;
6033 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6034 writeln!(code, " }}")?;
6035 writeln!(code)?;
6036
6037 writeln!(
6039 code,
6040 " /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6041 )?;
6042 writeln!(
6043 code,
6044 " fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6045 )?;
6046 writeln!(code, " let n: u32 = count as u32;")?;
6047 writeln!(
6048 code,
6049 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6050 )?;
6051 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6052 writeln!(
6053 code,
6054 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6055 )?;
6056 writeln!(
6057 code,
6058 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6059 )?;
6060 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6061 writeln!(
6062 code,
6063 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6064 )?;
6065 writeln!(
6066 code,
6067 " enc.dispatch_thread_groups(grid_size, tg_size);"
6068 )?;
6069 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6070 writeln!(code, " }}")?;
6071 writeln!(code)?;
6072
6073 writeln!(
6075 code,
6076 " /// Copy from src at byte offset to dst at float offset."
6077 )?;
6078 writeln!(
6079 code,
6080 " fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6081 )?;
6082 writeln!(code, " let n: u32 = count as u32;")?;
6083 writeln!(
6084 code,
6085 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6086 )?;
6087 writeln!(
6088 code,
6089 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6090 )?;
6091 writeln!(
6092 code,
6093 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6094 )?;
6095 writeln!(
6096 code,
6097 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6098 )?;
6099 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6100 writeln!(
6101 code,
6102 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6103 )?;
6104 writeln!(
6105 code,
6106 " enc.dispatch_thread_groups(grid_size, tg_size);"
6107 )?;
6108 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6109 writeln!(code, " }}")?;
6110 writeln!(code)?;
6111
6112 writeln!(
6114 code,
6115 " /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6116 )?;
6117 writeln!(
6118 code,
6119 " fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6120 )?;
6121 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6122 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6123 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6124 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6125 writeln!(code, " let so: u32 = src_offset as u32;")?;
6126 writeln!(
6127 code,
6128 " enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6129 )?;
6130 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6131 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6132 writeln!(
6133 code,
6134 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6135 )?;
6136 writeln!(
6137 code,
6138 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6139 )?;
6140 writeln!(
6141 code,
6142 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6143 )?;
6144 writeln!(
6145 code,
6146 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6147 )?;
6148 writeln!(
6149 code,
6150 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6151 )?;
6152 writeln!(code, " let total = num_tokens * kv_dim;")?;
6153 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6154 writeln!(
6155 code,
6156 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6157 )?;
6158 writeln!(
6159 code,
6160 " enc.dispatch_thread_groups(grid_size, tg_size);"
6161 )?;
6162 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6163 writeln!(code, " }}")?;
6164 writeln!(code)?;
6165
6166 writeln!(
6168 code,
6169 " /// Dispatch batched causal attention: one dispatch for all M tokens."
6170 )?;
6171 writeln!(
6172 code,
6173 " fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6174 )?;
6175 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6176 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6177 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
6178 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6179 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6180 writeln!(code, " let qs: u32 = q_stride as u32;")?;
6181 writeln!(code, " let max_seq = base_pos + num_tokens;")?;
6191 writeln!(code, " let _ = max_seq;")?;
6192 writeln!(
6193 code,
6194 " let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6195 )?;
6196 writeln!(code, " .map(|v| v == \"0\").unwrap_or(false);")?;
6197 writeln!(
6198 code,
6199 " let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6200 )?;
6201 writeln!(code, " if use_mma_flash {{")?;
6202 writeln!(
6203 code,
6204 " let pipe = &self.attention_mma_flash_batch_pipeline;"
6205 )?;
6206 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6207 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6208 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6209 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6210 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6211 writeln!(
6212 code,
6213 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6214 )?;
6215 writeln!(
6216 code,
6217 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6218 )?;
6219 writeln!(
6220 code,
6221 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6222 )?;
6223 writeln!(
6224 code,
6225 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6226 )?;
6227 writeln!(
6228 code,
6229 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6230 )?;
6231 writeln!(
6232 code,
6233 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6234 )?;
6235 writeln!(
6236 code,
6237 " // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6238 )?;
6239 writeln!(code, " let tg_size = MTLSize::new(128, 1, 1);")?;
6240 writeln!(
6241 code,
6242 " let q_blocks = ((num_tokens + 7) / 8) as u64;"
6243 )?;
6244 writeln!(
6245 code,
6246 " let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6247 )?;
6248 writeln!(
6249 code,
6250 " enc.dispatch_thread_groups(grid_size, tg_size);"
6251 )?;
6252 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6253 writeln!(code, " return;")?;
6254 writeln!(code, " }}")?;
6255 writeln!(code, " let pipe = &self.attention_batch_pipeline;")?;
6256 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6257 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6258 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6259 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6260 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6261 writeln!(
6262 code,
6263 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6264 )?;
6265 writeln!(
6266 code,
6267 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6268 )?;
6269 writeln!(
6270 code,
6271 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6272 )?;
6273 writeln!(
6274 code,
6275 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6276 )?;
6277 writeln!(
6278 code,
6279 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6280 )?;
6281 writeln!(
6282 code,
6283 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6284 )?;
6285 writeln!(
6286 code,
6287 " // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6288 )?;
6289 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6290 writeln!(
6291 code,
6292 " let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6293 )?;
6294 writeln!(
6295 code,
6296 " enc.dispatch_thread_groups(grid_size, tg_size);"
6297 )?;
6298 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6299 writeln!(code, " }}")?;
6300 writeln!(code)?;
6301
6302 writeln!(
6304 code,
6305 " /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6306 )?;
6307 writeln!(
6308 code,
6309 " /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6310 )?;
6311 writeln!(
6312 code,
6313 " fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6314 )?;
6315 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6316 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6317 writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
6318 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6319 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6320 writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
6321 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
6322 writeln!(
6323 code,
6324 " enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6325 )?;
6326 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
6327 writeln!(
6328 code,
6329 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6330 )?;
6331 writeln!(
6332 code,
6333 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6334 )?;
6335 writeln!(
6336 code,
6337 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6338 )?;
6339 writeln!(
6340 code,
6341 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6342 )?;
6343 writeln!(
6344 code,
6345 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6346 )?;
6347 writeln!(
6348 code,
6349 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6350 )?;
6351 writeln!(
6352 code,
6353 " enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6354 )?;
6355 writeln!(
6356 code,
6357 " let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6358 )?;
6359 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6360 writeln!(
6361 code,
6362 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6363 )?;
6364 writeln!(
6365 code,
6366 " enc.dispatch_thread_groups(grid_size, tg_size);"
6367 )?;
6368 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6369 writeln!(code, " }}")?;
6370 writeln!(code)?;
6371
6372 writeln!(
6374 code,
6375 " /// Dispatch fused K+V cache copy in one kernel launch."
6376 )?;
6377 writeln!(
6378 code,
6379 " /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6380 )?;
6381 writeln!(
6382 code,
6383 " fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6384 )?;
6385 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6386 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6387 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6388 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6389 writeln!(code, " let ko: u32 = k_offset as u32;")?;
6390 writeln!(code, " let vo: u32 = v_offset as u32;")?;
6391 writeln!(
6392 code,
6393 " enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6394 )?;
6395 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6396 writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
6397 writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
6398 writeln!(
6399 code,
6400 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6401 )?;
6402 writeln!(
6403 code,
6404 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6405 )?;
6406 writeln!(
6407 code,
6408 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6409 )?;
6410 writeln!(
6411 code,
6412 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6413 )?;
6414 writeln!(
6415 code,
6416 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6417 )?;
6418 writeln!(
6419 code,
6420 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6421 )?;
6422 writeln!(
6423 code,
6424 " let total = num_tokens * kv_dim * 2; // K + V"
6425 )?;
6426 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6427 writeln!(
6428 code,
6429 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6430 )?;
6431 writeln!(
6432 code,
6433 " enc.dispatch_thread_groups(grid_size, tg_size);"
6434 )?;
6435 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6436 writeln!(code, " }}")?;
6437
6438 writeln!(code, "}}")?;
6439 writeln!(code)?;
6440
6441 Ok(())
6442}
6443
6444fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6445 writeln!(
6446 code,
6447 "// ── Helper functions ──────────────────────────────────"
6448 )?;
6449 writeln!(code)?;
6450 writeln!(
6451 code,
6452 "/// Create a compute pipeline from a named function in the Metal library."
6453 )?;
6454 writeln!(
6455 code,
6456 "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6457 )?;
6458 writeln!(
6459 code,
6460 " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6461 )?;
6462 writeln!(
6463 code,
6464 " device.new_compute_pipeline_state_with_function(&func)"
6465 )?;
6466 writeln!(
6467 code,
6468 " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6469 )?;
6470 writeln!(code, "}}")?;
6471 writeln!(code)?;
6472
6473 Ok(())
6474}
6475
6476fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6481 let _sanitized = sanitize_name(model_name);
6482 let _vocab = config.vocab_size;
6483
6484 let mut code = String::with_capacity(16 * 1024);
6485 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6486 writeln!(
6487 code,
6488 "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6489 )?;
6490 writeln!(code)?;
6491 writeln!(code, "mod model;")?;
6492 writeln!(code)?;
6493 writeln!(code, "use std::io::Write;")?;
6494 writeln!(code, "use std::time::Instant;")?;
6495 writeln!(code, "use serde::Deserialize;")?;
6496 writeln!(code)?;
6497
6498 writeln!(code, "fn main() {{")?;
6500 writeln!(
6501 code,
6502 " let args: Vec<String> = std::env::args().collect();"
6503 )?;
6504 writeln!(code)?;
6505 writeln!(
6506 code,
6507 " // Detect --serve mode (only requires weights + tokenizer)"
6508 )?;
6509 writeln!(
6510 code,
6511 " let serve_mode = args.iter().any(|a| a == \"--serve\");"
6512 )?;
6513 writeln!(code)?;
6514 writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
6515 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6516 writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6517 writeln!(code, " std::process::exit(1);")?;
6518 writeln!(code, " }}")?;
6519 writeln!(code)?;
6520 writeln!(code, " if serve_mode && args.len() < 3 {{")?;
6521 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6522 writeln!(code, " std::process::exit(1);")?;
6523 writeln!(code, " }}")?;
6524 writeln!(code)?;
6525 writeln!(code, " let weights_path = &args[1];")?;
6526 writeln!(code, " let tokenizer_path = &args[2];")?;
6527 writeln!(code)?;
6528 writeln!(code, " // Parse optional flags")?;
6529 writeln!(code, " let mut max_tokens: usize = 128;")?;
6530 writeln!(code, " let mut port: u16 = 8080;")?;
6531 writeln!(
6532 code,
6533 " let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6534 )?;
6535 writeln!(
6536 code,
6537 " let profile = args.iter().any(|a| a == \"--profile\");"
6538 )?;
6539 writeln!(code, " let mut i = 3;")?;
6540 writeln!(code, " while i < args.len() {{")?;
6541 writeln!(
6542 code,
6543 " if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6544 )?;
6545 writeln!(
6546 code,
6547 " max_tokens = args[i + 1].parse().unwrap_or(128);"
6548 )?;
6549 writeln!(code, " i += 2;")?;
6550 writeln!(
6551 code,
6552 " }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6553 )?;
6554 writeln!(
6555 code,
6556 " port = args[i + 1].parse().unwrap_or(8080);"
6557 )?;
6558 writeln!(code, " i += 2;")?;
6559 writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
6560 writeln!(code, " i += 1;")?;
6561 writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
6562 writeln!(code, " i += 1;")?;
6563 writeln!(code, " }} else {{")?;
6564 writeln!(code, " i += 1;")?;
6565 writeln!(code, " }}")?;
6566 writeln!(code, " }}")?;
6567 writeln!(code)?;
6568
6569 writeln!(
6571 code,
6572 " // Memory-map weights for zero-copy loading on Apple Silicon"
6573 )?;
6574 writeln!(
6575 code,
6576 " let weights_file = std::fs::File::open(weights_path)"
6577 )?;
6578 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6579 writeln!(
6580 code,
6581 " let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6582 )?;
6583 writeln!(code)?;
6584 writeln!(code, " // Load tokenizer")?;
6585 writeln!(
6586 code,
6587 " let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6588 )?;
6589 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6590 writeln!(code)?;
6591 writeln!(code, " // Create Metal model")?;
6592 writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
6593 writeln!(
6594 code,
6595 " let mut model = model::MetalModel::new(&weights_mmap);"
6596 )?;
6597 writeln!(code)?;
6598
6599 writeln!(code, " if serve_mode {{")?;
6601 writeln!(code, " serve(model, tokenizer, port);")?;
6602 writeln!(code, " }} else {{")?;
6603 writeln!(code, " let prompt = &args[3];")?;
6604 writeln!(
6605 code,
6606 " cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6607 )?;
6608 writeln!(code, " }}")?;
6609 writeln!(code, "}}")?;
6610 writeln!(code)?;
6611
6612 writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6614 writeln!(code, " // Tokenize prompt")?;
6615 writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
6616 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6617 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6618 writeln!(code)?;
6619 writeln!(
6620 code,
6621 " // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6622 )?;
6623 writeln!(
6624 code,
6625 " // Uses double-buffered batch dispatch for GPU-efficient matmul."
6626 )?;
6627 writeln!(
6628 code,
6629 " // The last token uses synchronous forward() to get logits."
6630 )?;
6631 writeln!(code, " let prompt_len = prompt_tokens.len();")?;
6632 writeln!(code, " let prefill_start = Instant::now();")?;
6633 writeln!(code, " let logits = if prompt_len > 1 {{")?;
6634 writeln!(
6635 code,
6636 " model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6637 )?;
6638 writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
6639 writeln!(code, " }} else {{")?;
6640 writeln!(code, " model.forward(prompt_tokens[0])")?;
6641 writeln!(code, " }};")?;
6642 writeln!(
6643 code,
6644 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6645 )?;
6646 writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
6647 writeln!(
6648 code,
6649 " eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6650 )?;
6651 writeln!(
6652 code,
6653 " prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6654 )?;
6655 writeln!(code)?;
6656 writeln!(code, " // Generate tokens")?;
6657 writeln!(code, " let mut next_token = argmax(&logits);")?;
6658 writeln!(code, " let gen_start = Instant::now();")?;
6659 writeln!(code, " let mut generated_count: usize = 0;")?;
6660 writeln!(code)?;
6661 writeln!(code, " for _ in 0..max_tokens {{")?;
6662 writeln!(
6663 code,
6664 " if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6665 )?;
6666 writeln!(code, " if !quiet {{")?;
6667 writeln!(code, " print!(\"{{}}\", text);")?;
6668 writeln!(code, " std::io::stdout().flush().ok();")?;
6669 writeln!(code, " }}")?;
6670 writeln!(code, " }}")?;
6671 writeln!(code, " generated_count += 1;")?;
6672 writeln!(code)?;
6673 writeln!(
6674 code,
6675 " // Use profiling forward for first token when --profile is set"
6676 )?;
6677 writeln!(
6678 code,
6679 " let logits = if profile && generated_count == 1 {{"
6680 )?;
6681 writeln!(code, " model.forward_profile(next_token)")?;
6682 writeln!(code, " }} else {{")?;
6683 writeln!(code, " model.forward(next_token)")?;
6684 writeln!(code, " }};")?;
6685 writeln!(code, " next_token = argmax(&logits);")?;
6686 writeln!(code)?;
6687 writeln!(code, " // Stop on EOS (token 2 for most models)")?;
6688 writeln!(code, " if next_token == 2 {{")?;
6689 writeln!(code, " break;")?;
6690 writeln!(code, " }}")?;
6691 writeln!(code)?;
6692 writeln!(
6693 code,
6694 " // Yield between tokens to reduce sustained GPU thermal load."
6695 )?;
6696 writeln!(
6697 code,
6698 " // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6699 )?;
6700 writeln!(
6701 code,
6702 " // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6703 )?;
6704 writeln!(
6705 code,
6706 " // briefly, providing a micro-break that helps sustain peak throughput."
6707 )?;
6708 writeln!(code, " std::thread::yield_now();")?;
6709 writeln!(code, " }}")?;
6710 writeln!(code, " if !quiet {{")?;
6711 writeln!(code, " println!();")?;
6712 writeln!(code, " }}")?;
6713 writeln!(
6714 code,
6715 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6716 )?;
6717 writeln!(
6718 code,
6719 " eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6720 )?;
6721 writeln!(
6722 code,
6723 " generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6724 )?;
6725 writeln!(code, "}}")?;
6726 writeln!(code)?;
6727
6728 writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6730 writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6731 writeln!(code, " logits.iter()")?;
6732 writeln!(code, " .enumerate()")?;
6733 writeln!(
6734 code,
6735 " .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6736 )?;
6737 writeln!(code, " .map(|(i, _)| i as u32)")?;
6738 writeln!(code, " .unwrap_or(0)")?;
6739 writeln!(code, "}}")?;
6740 writeln!(code)?;
6741
6742 writeln!(
6744 code,
6745 "// -----------------------------------------------------------------------"
6746 )?;
6747 writeln!(code, "// OpenAI-compatible API server")?;
6748 writeln!(
6749 code,
6750 "// -----------------------------------------------------------------------"
6751 )?;
6752 writeln!(code)?;
6753 writeln!(code, "#[derive(Deserialize)]")?;
6754 writeln!(code, "struct ChatRequest {{")?;
6755 writeln!(code, " messages: Vec<ChatMessage>,")?;
6756 writeln!(code, " #[serde(default)]")?;
6757 writeln!(code, " stream: Option<bool>,")?;
6758 writeln!(code, " #[serde(default)]")?;
6759 writeln!(code, " max_tokens: Option<usize>,")?;
6760 writeln!(code, " #[serde(default)]")?;
6761 writeln!(code, " temperature: Option<f32>,")?;
6762 writeln!(code, " #[serde(default)]")?;
6763 writeln!(code, " model: Option<String>,")?;
6764 writeln!(code, "}}")?;
6765 writeln!(code)?;
6766 writeln!(code, "#[derive(Deserialize)]")?;
6767 writeln!(code, "struct ChatMessage {{")?;
6768 writeln!(code, " role: String,")?;
6769 writeln!(code, " content: String,")?;
6770 writeln!(code, "}}")?;
6771 writeln!(code)?;
6772
6773 writeln!(
6775 code,
6776 "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6777 )?;
6778 writeln!(code, " let mut prompt = String::new();")?;
6779 writeln!(code, " for msg in messages {{")?;
6780 writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6781 writeln!(code, " }}")?;
6782 writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6783 writeln!(code, " prompt")?;
6784 writeln!(code, "}}")?;
6785 writeln!(code)?;
6786
6787 writeln!(
6789 code,
6790 "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6791 )?;
6792 writeln!(code, " let len = tokens.len();")?;
6793 writeln!(code, " if len > 1 {{")?;
6794 writeln!(
6795 code,
6796 " model.forward_prefill_batch(&tokens[..len - 1]);"
6797 )?;
6798 writeln!(code, " }}")?;
6799 writeln!(code, " model.forward(tokens[len - 1])")?;
6800 writeln!(code, "}}")?;
6801 writeln!(code)?;
6802
6803 writeln!(
6805 code,
6806 "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6807 )?;
6808 writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6809 writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
6810 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6811 writeln!(
6812 code,
6813 " eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6814 )?;
6815 writeln!(code, " eprintln!(\"Endpoints:\");")?;
6816 writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
6817 writeln!(code, " eprintln!(\" GET /v1/models\");")?;
6818 writeln!(code, " eprintln!(\" GET /health\");")?;
6819 writeln!(code)?;
6820 writeln!(code, " for request in server.incoming_requests() {{")?;
6821 writeln!(code, " let method = request.method().to_string();")?;
6822 writeln!(code, " let url = request.url().to_string();")?;
6823 writeln!(code)?;
6824 writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
6825
6826 writeln!(
6828 code,
6829 " (\"POST\", \"/v1/chat/completions\") => {{"
6830 )?;
6831 writeln!(
6832 code,
6833 " handle_chat_completion(&mut model, &tokenizer, request);"
6834 )?;
6835 writeln!(code, " }}")?;
6836
6837 writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
6839 writeln!(code, " let body = serde_json::json!({{")?;
6840 writeln!(code, " \"object\": \"list\",")?;
6841 writeln!(code, " \"data\": [{{")?;
6842 writeln!(code, " \"id\": \"forgellm-metal\",")?;
6843 writeln!(code, " \"object\": \"model\",")?;
6844 writeln!(code, " \"owned_by\": \"forgellm\"")?;
6845 writeln!(code, " }}]")?;
6846 writeln!(code, " }});")?;
6847 writeln!(
6848 code,
6849 " let resp = tiny_http::Response::from_string(body.to_string())"
6850 )?;
6851 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6852 writeln!(code, " request.respond(resp).ok();")?;
6853 writeln!(code, " }}")?;
6854
6855 writeln!(code, " (\"GET\", \"/health\") => {{")?;
6857 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6858 writeln!(code, " request.respond(resp).ok();")?;
6859 writeln!(code, " }}")?;
6860
6861 writeln!(code, " _ => {{")?;
6863 writeln!(
6864 code,
6865 " let resp = tiny_http::Response::from_string(\"Not Found\")"
6866 )?;
6867 writeln!(code, " .with_status_code(404);")?;
6868 writeln!(code, " request.respond(resp).ok();")?;
6869 writeln!(code, " }}")?;
6870 writeln!(code, " }}")?;
6871 writeln!(code, " }}")?;
6872 writeln!(code, "}}")?;
6873 writeln!(code)?;
6874
6875 writeln!(code, "fn handle_chat_completion(")?;
6877 writeln!(code, " model: &mut model::MetalModel,")?;
6878 writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
6879 writeln!(code, " mut request: tiny_http::Request,")?;
6880 writeln!(code, ") {{")?;
6881 writeln!(code, " // Read request body")?;
6882 writeln!(code, " let mut body = String::new();")?;
6883 writeln!(
6884 code,
6885 " if request.as_reader().read_to_string(&mut body).is_err() {{"
6886 )?;
6887 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6888 writeln!(code, " .with_status_code(400);")?;
6889 writeln!(code, " request.respond(resp).ok();")?;
6890 writeln!(code, " return;")?;
6891 writeln!(code, " }}")?;
6892 writeln!(code)?;
6893 writeln!(code, " // Parse JSON")?;
6894 writeln!(
6895 code,
6896 " let req: ChatRequest = match serde_json::from_str(&body) {{"
6897 )?;
6898 writeln!(code, " Ok(r) => r,")?;
6899 writeln!(code, " Err(e) => {{")?;
6900 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6901 writeln!(code, " .with_status_code(400);")?;
6902 writeln!(code, " request.respond(resp).ok();")?;
6903 writeln!(code, " return;")?;
6904 writeln!(code, " }}")?;
6905 writeln!(code, " }};")?;
6906 writeln!(code)?;
6907 writeln!(
6908 code,
6909 " let prompt = format_chat_messages(&req.messages);"
6910 )?;
6911 writeln!(
6912 code,
6913 " let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6914 )?;
6915 writeln!(code, " Ok(e) => e,")?;
6916 writeln!(code, " Err(e) => {{")?;
6917 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6918 writeln!(code, " .with_status_code(500);")?;
6919 writeln!(code, " request.respond(resp).ok();")?;
6920 writeln!(code, " return;")?;
6921 writeln!(code, " }}")?;
6922 writeln!(code, " }};")?;
6923 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6924 writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
6925 writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
6926 writeln!(
6927 code,
6928 " let _temperature = req.temperature.unwrap_or(1.0);"
6929 )?;
6930 writeln!(code)?;
6931
6932 writeln!(code, " model.reset();")?;
6934 writeln!(code)?;
6935
6936 writeln!(code, " let prefill_start = Instant::now();")?;
6938 writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
6939 writeln!(
6940 code,
6941 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6942 )?;
6943 writeln!(code, " let prefill_count = prompt_tokens.len();")?;
6944 writeln!(code, " let mut next_token = argmax(&logits);")?;
6945 writeln!(code)?;
6946
6947 writeln!(code, " if stream {{")?;
6948
6949 writeln!(
6951 code,
6952 " // SSE streaming: generate tokens and build SSE body"
6953 )?;
6954 writeln!(code, " let gen_start = Instant::now();")?;
6955 writeln!(code, " let mut generated_count: usize = 0;")?;
6956 writeln!(code, " let mut sse_body = String::new();")?;
6957 writeln!(code, " for _ in 0..max_tokens {{")?;
6958 writeln!(code, " if next_token == 2 {{ break; }}")?;
6959 writeln!(
6960 code,
6961 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6962 )?;
6963 writeln!(
6964 code,
6965 " let escaped = serde_json::to_string(&text).unwrap_or_default();"
6966 )?;
6967 writeln!(
6968 code,
6969 " // escaped includes surrounding quotes, strip them"
6970 )?;
6971 writeln!(
6972 code,
6973 " let inner = &escaped[1..escaped.len()-1];"
6974 )?;
6975 writeln!(code, " sse_body.push_str(&format!(")?;
6976 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6977 writeln!(code, " inner")?;
6978 writeln!(code, " ));")?;
6979 writeln!(code, " }}")?;
6980 writeln!(code, " generated_count += 1;")?;
6981 writeln!(code, " let logits = model.forward(next_token);")?;
6982 writeln!(code, " next_token = argmax(&logits);")?;
6983 writeln!(code, " }}")?;
6984 writeln!(
6985 code,
6986 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6987 )?;
6988 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6989 writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
6990 writeln!(code)?;
6991 writeln!(
6992 code,
6993 " // Final chunk with finish_reason, timing, and DONE sentinel"
6994 )?;
6995 writeln!(code, " sse_body.push_str(&format!(")?;
6996 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6997 writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6998 writeln!(code, " ));")?;
6999 writeln!(code)?;
7000 writeln!(
7001 code,
7002 " let resp = tiny_http::Response::from_string(sse_body)"
7003 )?;
7004 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
7005 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
7006 writeln!(code, " request.respond(resp).ok();")?;
7007
7008 writeln!(code, " }} else {{")?;
7009
7010 writeln!(
7012 code,
7013 " // Non-streaming: generate all tokens, return JSON"
7014 )?;
7015 writeln!(code, " let gen_start = Instant::now();")?;
7016 writeln!(code, " let mut generated_count: usize = 0;")?;
7017 writeln!(code, " let mut generated = String::new();")?;
7018 writeln!(code, " for _ in 0..max_tokens {{")?;
7019 writeln!(code, " if next_token == 2 {{ break; }}")?;
7020 writeln!(
7021 code,
7022 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7023 )?;
7024 writeln!(code, " generated.push_str(&text);")?;
7025 writeln!(code, " }}")?;
7026 writeln!(code, " generated_count += 1;")?;
7027 writeln!(code, " let logits = model.forward(next_token);")?;
7028 writeln!(code, " next_token = argmax(&logits);")?;
7029 writeln!(code, " }}")?;
7030 writeln!(
7031 code,
7032 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7033 )?;
7034 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7035 writeln!(code)?;
7036 writeln!(code, " let resp_json = serde_json::json!({{")?;
7037 writeln!(code, " \"id\": \"chatcmpl-1\",")?;
7038 writeln!(code, " \"object\": \"chat.completion\",")?;
7039 writeln!(code, " \"choices\": [{{")?;
7040 writeln!(code, " \"index\": 0,")?;
7041 writeln!(code, " \"message\": {{")?;
7042 writeln!(code, " \"role\": \"assistant\",")?;
7043 writeln!(code, " \"content\": generated")?;
7044 writeln!(code, " }},")?;
7045 writeln!(code, " \"finish_reason\": \"stop\"")?;
7046 writeln!(code, " }}],")?;
7047 writeln!(code, " \"usage\": {{")?;
7048 writeln!(code, " \"prefill_tokens\": prefill_count,")?;
7049 writeln!(
7050 code,
7051 " \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7052 )?;
7053 writeln!(
7054 code,
7055 " \"generation_tokens\": generated_count,"
7056 )?;
7057 writeln!(
7058 code,
7059 " \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7060 )?;
7061 writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
7062 writeln!(code, " }}")?;
7063 writeln!(code, " }});")?;
7064 writeln!(
7065 code,
7066 " let resp = tiny_http::Response::from_string(resp_json.to_string())"
7067 )?;
7068 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7069 writeln!(code, " request.respond(resp).ok();")?;
7070 writeln!(code, " }}")?;
7071 writeln!(code, "}}")?;
7072
7073 Ok(code)
7074}
7075
7076#[cfg(test)]
7081mod tests {
7082 use super::*;
7083 use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7084
7085 fn minimal_config() -> ModelConfig {
7086 ModelConfig {
7087 architecture: Architecture::Llama,
7088 hidden_size: 64,
7089 intermediate_size: 128,
7090 num_layers: 2,
7091 num_attention_heads: 4,
7092 num_kv_heads: 4,
7093 head_dim: 16,
7094 vocab_size: 256,
7095 max_seq_len: 512,
7096 rms_norm_eps: 1e-5,
7097 rope_theta: 10000.0,
7098 dtype: DType::F32,
7099 sliding_window_size: None,
7100 qkv_bias: false,
7101 }
7102 }
7103
7104 fn minimal_graph() -> Graph {
7105 Graph::new("test-metal").with_config(minimal_config())
7106 }
7107
7108 #[test]
7109 fn generate_metal_project_creates_files() {
7110 let dir = tempfile::tempdir().unwrap();
7111 let graph = minimal_graph();
7112 generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7113
7114 assert!(
7115 dir.path().join("Cargo.toml").exists(),
7116 "Cargo.toml should be created"
7117 );
7118 assert!(
7119 dir.path().join("src/model.rs").exists(),
7120 "src/model.rs should be created"
7121 );
7122 assert!(
7123 dir.path().join("src/main.rs").exists(),
7124 "src/main.rs should be created"
7125 );
7126 assert!(
7127 dir.path().join("shaders/kernels.metal").exists(),
7128 "shaders/kernels.metal should be created"
7129 );
7130 }
7131
7132 #[test]
7133 fn generated_cargo_toml_has_metal_dep() {
7134 let toml = generate_cargo_toml("my-model");
7135 assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7136 assert!(
7137 toml.contains("tokenizers"),
7138 "Cargo.toml should depend on tokenizers"
7139 );
7140 assert!(
7141 toml.contains("memmap2"),
7142 "Cargo.toml should depend on memmap2"
7143 );
7144 assert!(toml.contains("half"), "Cargo.toml should depend on half");
7145 }
7146
7147 #[test]
7148 fn generated_model_rs_contains_metal_code() {
7149 let config = minimal_config();
7150 let model_rs = generate_model_rs(&config).unwrap();
7151
7152 assert!(
7153 model_rs.contains("pub struct MetalModel"),
7154 "model.rs should define MetalModel struct"
7155 );
7156 assert!(
7157 model_rs.contains("matmul_pipeline: ComputePipelineState"),
7158 "MetalModel should have matmul_pipeline field"
7159 );
7160 assert!(
7161 model_rs.contains("Device::system_default()"),
7162 "model.rs should use Metal device"
7163 );
7164 assert!(
7165 model_rs.contains("new_library_with_source"),
7166 "model.rs should compile Metal shaders"
7167 );
7168 assert!(
7169 model_rs.contains("fn new(weights: &[u8])"),
7170 "MetalModel should implement new()"
7171 );
7172 assert!(
7173 model_rs.contains("fn forward(&mut self, token_id: u32)"),
7174 "MetalModel should implement forward()"
7175 );
7176 }
7177
7178 #[test]
7179 fn generated_shaders_contain_kernel_names() {
7180 let shaders = generate_metal_shaders(&minimal_config());
7181
7182 assert!(
7183 shaders.contains("kernel void matmul_vec"),
7184 "shaders should contain matmul_vec kernel"
7185 );
7186 assert!(
7187 shaders.contains("kernel void rms_norm"),
7188 "shaders should contain rms_norm kernel"
7189 );
7190 assert!(
7191 shaders.contains("kernel void rope"),
7192 "shaders should contain rope kernel"
7193 );
7194 assert!(
7195 shaders.contains("kernel void softmax"),
7196 "shaders should contain softmax kernel"
7197 );
7198 assert!(
7199 shaders.contains("kernel void silu_mul("),
7200 "shaders should contain silu_mul kernel"
7201 );
7202 assert!(
7203 shaders.contains("kernel void silu_mul_fused"),
7204 "shaders should contain silu_mul_fused kernel"
7205 );
7206 assert!(
7207 shaders.contains("kernel void elementwise_add"),
7208 "shaders should contain elementwise_add kernel"
7209 );
7210 assert!(
7211 shaders.contains("kernel void attention"),
7212 "shaders should contain attention kernel"
7213 );
7214 assert!(
7215 shaders.contains("kernel void add_inplace"),
7216 "shaders should contain add_inplace kernel"
7217 );
7218 assert!(
7219 shaders.contains("kernel void copy_buffer"),
7220 "shaders should contain copy_buffer kernel"
7221 );
7222 assert!(
7223 shaders.contains("kernel void copy_offset"),
7224 "shaders should contain copy_offset kernel"
7225 );
7226 }
7227
7228 #[test]
7229 fn generated_shaders_use_simdgroup_features() {
7230 let shaders = generate_metal_shaders(&minimal_config());
7231
7232 assert!(
7233 shaders.contains("threadgroup_barrier"),
7234 "shaders should use threadgroup barriers"
7235 );
7236 assert!(
7237 shaders.contains("threadgroup float"),
7238 "shaders should use threadgroup shared memory"
7239 );
7240 assert!(
7241 shaders.contains("thread_index_in_threadgroup"),
7242 "shaders should use threadgroup indexing"
7243 );
7244 assert!(
7245 shaders.contains("simd_sum"),
7246 "shaders should use simd_sum for warp-level reduction"
7247 );
7248 assert!(
7249 shaders.contains("simd_max"),
7250 "attention kernel should use simd_max for cooperative softmax"
7251 );
7252 assert!(
7253 shaders.contains("thread_index_in_simdgroup"),
7254 "shaders should use simdgroup lane indexing"
7255 );
7256 assert!(
7257 shaders.contains("simdgroup_index_in_threadgroup"),
7258 "shaders should use simdgroup indexing within threadgroup"
7259 );
7260 assert!(
7261 shaders.contains("float4"),
7262 "matmul_vec should use float4 vectorized loads"
7263 );
7264 }
7265
7266 #[test]
7267 fn generated_main_rs_has_tokenizer_usage() {
7268 let config = minimal_config();
7269 let main_rs = generate_main_rs("test-model", &config).unwrap();
7270
7271 assert!(
7272 main_rs.contains("tokenizers::Tokenizer"),
7273 "main.rs should use tokenizers crate"
7274 );
7275 assert!(
7276 main_rs.contains("MetalModel::new"),
7277 "main.rs should call MetalModel::new"
7278 );
7279 assert!(
7280 main_rs.contains("model.forward"),
7281 "main.rs should call model.forward"
7282 );
7283 assert!(
7284 main_rs.contains("memmap2"),
7285 "main.rs should use memmap2 for zero-copy weight loading"
7286 );
7287 }
7288
7289 #[test]
7290 fn missing_config_returns_error() {
7291 let dir = tempfile::tempdir().unwrap();
7292 let graph = Graph::new("no-config");
7293 let result = generate_metal_project(&graph, dir.path(), "fail");
7294 assert!(
7295 matches!(result, Err(MetalCodegenError::MissingConfig)),
7296 "should fail with MissingConfig when graph has no config"
7297 );
7298 }
7299
7300 #[test]
7301 fn sanitize_name_works() {
7302 assert_eq!(sanitize_name("My Model!"), "my-model");
7303 assert_eq!(sanitize_name("test_model"), "test-model");
7304 assert_eq!(sanitize_name("simple"), "simple");
7305 }
7306
7307 #[test]
7308 fn generated_forward_uses_single_command_buffer() {
7309 let config = minimal_config();
7310 let model_rs = generate_model_rs(&config).unwrap();
7311
7312 let forward_start = model_rs
7315 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7316 .unwrap();
7317 let forward_body = &model_rs[forward_start..];
7318 let forward_end = forward_body
7320 .find("\n pub fn forward_profile")
7321 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7322 .or_else(|| forward_body.find("\n fn dispatch_"))
7323 .unwrap_or(forward_body.len());
7324 let forward_code = &forward_body[..forward_end];
7325
7326 let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7328 assert_eq!(
7329 cmd_buf_count, 1,
7330 "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7331 );
7332
7333 let commit_count = forward_code.matches("cmd.commit()").count();
7335 assert_eq!(
7336 commit_count, 1,
7337 "forward() should commit exactly once, found {commit_count}"
7338 );
7339
7340 let wait_count = forward_code.matches("wait_until_completed()").count();
7342 assert!(
7343 wait_count >= 1 && wait_count <= 2,
7344 "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7345 );
7346 }
7347
7348 #[test]
7349 fn generated_model_has_preallocated_working_buffers() {
7350 let config = minimal_config();
7351 let model_rs = generate_model_rs(&config).unwrap();
7352
7353 for buf_name in &[
7354 "normed_buf",
7355 "qkv_buf",
7356 "attn_out_buf",
7357 "attn_proj_buf",
7358 "gate_up_buf",
7359 "ffn_hidden_buf",
7360 "ffn_out_buf",
7361 "add_tmp_buf",
7362 ] {
7363 assert!(
7364 model_rs.contains(&format!("{buf_name}: Buffer")),
7365 "MetalModel should have pre-allocated {buf_name} field"
7366 );
7367 }
7368 }
7369
7370 #[test]
7371 fn generated_dispatch_helpers_take_compute_encoder_ref() {
7372 let config = minimal_config();
7373 let model_rs = generate_model_rs(&config).unwrap();
7374
7375 for method in &[
7376 "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7377 "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7378 "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7379 "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7380 "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7381 "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7382 "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7383 "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7384 "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7385 "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7386 "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7387 "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7388 "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7389 ] {
7390 assert!(
7391 model_rs.contains(method),
7392 "model.rs should contain dispatch helper: {method}"
7393 );
7394 }
7395 }
7396
7397 #[test]
7398 fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7399 let config = minimal_config();
7400 let model_rs = generate_model_rs(&config).unwrap();
7401
7402 let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7404 let helpers_code = &model_rs[helpers_start..];
7405
7406 assert!(
7408 !helpers_code.contains("self.queue.new_command_buffer()"),
7409 "dispatch helpers should not create their own command buffers"
7410 );
7411
7412 assert!(
7414 !helpers_code.contains("new_compute_command_encoder()"),
7415 "dispatch helpers should not create their own compute encoders"
7416 );
7417
7418 assert!(
7420 !helpers_code.contains("end_encoding()"),
7421 "dispatch helpers should not call end_encoding"
7422 );
7423
7424 assert!(
7426 !helpers_code.contains(".commit()"),
7427 "dispatch helpers should not commit command buffers"
7428 );
7429 assert!(
7430 !helpers_code.contains("wait_until_completed"),
7431 "dispatch helpers should not wait on command buffers"
7432 );
7433 }
7434
7435 #[test]
7436 fn generated_forward_batches_compute_encoders() {
7437 let config = minimal_config();
7438 let model_rs = generate_model_rs(&config).unwrap();
7439
7440 let forward_start = model_rs
7442 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7443 .unwrap();
7444 let forward_body = &model_rs[forward_start..];
7445 let forward_end = forward_body
7446 .find("\n pub fn forward_profile")
7447 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7448 .or_else(|| forward_body.find("\n fn dispatch_"))
7449 .unwrap_or(forward_body.len());
7450 let forward_code = &forward_body[..forward_end];
7451
7452 assert!(
7454 !forward_code.contains("device.new_buffer"),
7455 "forward() should not allocate new buffers per call"
7456 );
7457
7458 let compute_encoder_count = forward_code
7461 .matches("new_compute_command_encoder()")
7462 .count();
7463 let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7464
7465 assert_eq!(
7467 compute_encoder_count, 1,
7468 "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7469 );
7470 assert_eq!(
7471 blit_encoder_count, 0,
7472 "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7473 );
7474 }
7475
7476 #[test]
7477 fn generated_forward_uses_add_inplace() {
7478 let config = minimal_config();
7479 let model_rs = generate_model_rs(&config).unwrap();
7480
7481 assert!(
7483 model_rs.contains("dispatch_add_inplace"),
7484 "forward() should use dispatch_add_inplace for residual connections"
7485 );
7486 assert!(
7487 model_rs.contains("add_inplace_pipeline"),
7488 "MetalModel should have add_inplace_pipeline"
7489 );
7490 }
7491
7492 fn minimal_q8_config() -> ModelConfig {
7493 ModelConfig {
7494 architecture: Architecture::Llama,
7495 hidden_size: 64,
7496 intermediate_size: 128,
7497 num_layers: 2,
7498 num_attention_heads: 4,
7499 num_kv_heads: 4,
7500 head_dim: 16,
7501 vocab_size: 256,
7502 max_seq_len: 512,
7503 rms_norm_eps: 1e-5,
7504 rope_theta: 10000.0,
7505 dtype: DType::Q8_0,
7506 sliding_window_size: None,
7507 qkv_bias: false,
7508 }
7509 }
7510
7511 #[test]
7512 fn generated_shaders_contain_q8_kernel() {
7513 let shaders = generate_metal_shaders(&minimal_config());
7514
7515 assert!(
7516 shaders.contains("kernel void matmul_vec_q8"),
7517 "shaders should contain matmul_vec_q8 kernel"
7518 );
7519 assert!(
7520 shaders.contains("device const uchar* matrix"),
7521 "matmul_vec_q8 should accept raw Q8_0 bytes"
7522 );
7523 assert!(
7524 shaders.contains("packed_short4"),
7525 "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7526 );
7527 assert!(
7528 shaders.contains("as_type<char2>"),
7529 "matmul_vec_q8 should bitcast short lanes to char2"
7530 );
7531 assert!(
7532 shaders.contains("device const half*"),
7533 "matmul_vec_q8 should read f16 scale via half pointer"
7534 );
7535 }
7536
7537 #[test]
7538 fn generated_model_uses_fused_qkv_projections() {
7539 let config = minimal_config();
7540 let model_rs = generate_model_rs(&config).unwrap();
7541
7542 assert!(
7544 model_rs.contains("qkv_weight: Buffer"),
7545 "LayerBuffers should have fused qkv_weight field"
7546 );
7547 assert!(
7549 !model_rs.contains(" q_weight: Buffer"),
7550 "LayerBuffers should not have separate q_weight field"
7551 );
7552 assert!(
7553 !model_rs.contains(" k_weight: Buffer"),
7554 "LayerBuffers should not have separate k_weight field"
7555 );
7556 assert!(
7557 !model_rs.contains(" v_weight: Buffer"),
7558 "LayerBuffers should not have separate v_weight field"
7559 );
7560
7561 assert!(
7563 model_rs.contains("gate_up_weight: Buffer"),
7564 "LayerBuffers should have fused gate_up_weight field"
7565 );
7566 assert!(
7568 !model_rs.contains(" gate_weight: Buffer"),
7569 "LayerBuffers should not have separate gate_weight field"
7570 );
7571 assert!(
7572 !model_rs.contains(" up_weight: Buffer"),
7573 "LayerBuffers should not have separate up_weight field"
7574 );
7575
7576 assert!(
7578 model_rs.contains("qkv_buf: Buffer"),
7579 "MetalModel should have fused qkv_buf"
7580 );
7581 assert!(
7582 model_rs.contains("gate_up_buf: Buffer"),
7583 "MetalModel should have fused gate_up_buf"
7584 );
7585
7586 assert!(
7588 model_rs.contains("dispatch_silu_mul_fused"),
7589 "forward pass should use dispatch_silu_mul_fused"
7590 );
7591 assert!(
7592 model_rs.contains("dispatch_rope_offset"),
7593 "forward pass should use dispatch_rope_offset for fused QKV"
7594 );
7595 assert!(
7596 model_rs.contains("dispatch_attention_offset"),
7597 "forward pass should use dispatch_attention_offset for fused QKV"
7598 );
7599 }
7600
7601 #[test]
7602 fn q8_model_has_matmul_q8_pipeline() {
7603 let config = minimal_q8_config();
7604 let model_rs = generate_model_rs(&config).unwrap();
7605
7606 assert!(
7607 model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7608 "MetalModel should have matmul_q8_pipeline field"
7609 );
7610 assert!(
7611 model_rs.contains("matmul_q8_pipeline,"),
7612 "MetalModel Self should include matmul_q8_pipeline"
7613 );
7614 }
7615
7616 #[test]
7617 fn q8_model_uses_dispatch_matmul_q8() {
7618 let config = minimal_q8_config();
7619 let model_rs = generate_model_rs(&config).unwrap();
7620
7621 assert!(
7622 model_rs.contains("dispatch_matmul_q8"),
7623 "Q8_0 model should use dispatch_matmul_q8 for projections"
7624 );
7625 assert!(
7626 model_rs.contains("fn dispatch_matmul_q8"),
7627 "model.rs should define dispatch_matmul_q8 method"
7628 );
7629 }
7630
7631 #[test]
7632 fn q8_model_loads_raw_bytes_not_dequantized() {
7633 let config = minimal_q8_config();
7634 let model_rs = generate_model_rs(&config).unwrap();
7635
7636 assert!(
7638 !model_rs.contains("f16_to_f32"),
7639 "Q8_0 model should not dequantize weights to f32"
7640 );
7641 assert!(
7642 !model_rs.contains("f32_data"),
7643 "Q8_0 model should not create f32 weight data"
7644 );
7645
7646 assert!(
7648 model_rs.contains("total_raw as u64"),
7649 "Q8_0 model should load raw bytes into Metal buffer"
7650 );
7651 }
7652
7653 #[test]
7654 fn q8_model_norms_stay_f32() {
7655 let config = minimal_q8_config();
7656 let model_rs = generate_model_rs(&config).unwrap();
7657
7658 assert!(
7660 model_rs.contains("let attn_norm = next_f32_buffer"),
7661 "attn_norm should use f32 buffer even for Q8_0 models"
7662 );
7663 assert!(
7664 model_rs.contains("let ffn_norm = next_f32_buffer"),
7665 "ffn_norm should use f32 buffer even for Q8_0 models"
7666 );
7667 assert!(
7668 model_rs.contains("let norm_buf = next_f32_buffer"),
7669 "final norm should use f32 buffer even for Q8_0 models"
7670 );
7671 }
7672
7673 #[test]
7674 fn q8_model_uses_fused_weight_loading() {
7675 let config = minimal_q8_config();
7676 let model_rs = generate_model_rs(&config).unwrap();
7677
7678 assert!(
7680 model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7681 "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7682 );
7683 assert!(
7685 model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7686 "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7687 );
7688 assert!(
7690 model_rs.contains("let o_weight = next_q8_buffer"),
7691 "Q8_0 model should use next_q8_buffer for O weight"
7692 );
7693 assert!(
7694 model_rs.contains("let down_weight = next_q8_buffer"),
7695 "Q8_0 model should use next_q8_buffer for down weight"
7696 );
7697 }
7698
7699 #[test]
7700 fn f32_model_does_not_use_q8_dispatch() {
7701 let config = minimal_config();
7702 let model_rs = generate_model_rs(&config).unwrap();
7703
7704 let forward_start = model_rs
7706 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7707 .unwrap();
7708 let forward_body = &model_rs[forward_start..];
7709 let forward_end = forward_body
7710 .find("\n fn dispatch_")
7711 .unwrap_or(forward_body.len());
7712 let forward_code = &forward_body[..forward_end];
7713
7714 assert!(
7715 !forward_code.contains("dispatch_matmul_q8"),
7716 "f32 model forward should not use dispatch_matmul_q8"
7717 );
7718 }
7719
7720 #[test]
7721 fn q8_dispatch_helper_takes_compute_encoder_ref() {
7722 let config = minimal_q8_config();
7723 let model_rs = generate_model_rs(&config).unwrap();
7724
7725 assert!(
7726 model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7727 "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7728 );
7729 }
7730
7731 #[test]
7732 fn generated_model_has_double_buffered_prefill() {
7733 let config = minimal_config();
7734 let model_rs = generate_model_rs(&config).unwrap();
7735
7736 assert!(
7738 model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7739 "MetalModel should have prev_cmd field for double-buffered prefill"
7740 );
7741
7742 assert!(
7744 model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7745 "MetalModel should have forward_prefill method"
7746 );
7747
7748 assert!(
7750 model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7751 "forward() should drain prev_cmd from previous prefill"
7752 );
7753 }
7754
7755 #[test]
7756 fn generated_main_rs_uses_forward_prefill_for_prompt() {
7757 let config = minimal_config();
7758 let main_rs = generate_main_rs("test-model", &config).unwrap();
7759
7760 assert!(
7761 main_rs.contains("forward_prefill"),
7762 "main.rs should use forward_prefill for intermediate prompt tokens"
7763 );
7764 assert!(
7765 main_rs.contains("double-buffered"),
7766 "main.rs should document double-buffered prefill"
7767 );
7768 }
7769
7770 #[test]
7771 fn generated_shaders_q8_uses_wide_vectorized_loads() {
7772 let shaders = generate_metal_shaders(&minimal_config());
7773
7774 assert!(
7775 shaders.contains("packed_short4"),
7776 "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7777 );
7778 assert!(
7779 shaders.contains("d0[0]"),
7780 "matmul_vec_q8 should index the wide pointer for row 0"
7781 );
7782 assert!(
7783 shaders.contains("as_type<char2>"),
7784 "matmul_vec_q8 should bitcast short lanes to char2"
7785 );
7786 assert!(
7787 shaders.contains("dot("),
7788 "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7789 );
7790 }
7791
7792 fn minimal_q4_config() -> ModelConfig {
7795 ModelConfig {
7796 architecture: Architecture::Llama,
7797 hidden_size: 64,
7798 intermediate_size: 128,
7799 num_layers: 2,
7800 num_attention_heads: 4,
7801 num_kv_heads: 4,
7802 head_dim: 16,
7803 vocab_size: 256,
7804 max_seq_len: 512,
7805 rms_norm_eps: 1e-5,
7806 rope_theta: 10000.0,
7807 dtype: DType::Q4_0,
7808 sliding_window_size: None,
7809 qkv_bias: false,
7810 }
7811 }
7812
7813 #[test]
7814 fn generated_shaders_contain_q4_kernel() {
7815 let shaders = generate_metal_shaders(&minimal_config());
7816
7817 assert!(
7818 shaders.contains("kernel void matmul_vec_q4"),
7819 "shaders should contain matmul_vec_q4 kernel"
7820 );
7821 assert!(
7822 shaders.contains("Q4_ROWS_PER_TG"),
7823 "shaders should define Q4_ROWS_PER_TG constant"
7824 );
7825 assert!(
7826 shaders.contains("Q4_ROWS_PER_SG"),
7827 "shaders should define Q4_ROWS_PER_SG constant"
7828 );
7829 }
7830
7831 #[test]
7832 fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7833 let shaders = generate_metal_shaders(&minimal_config());
7834
7835 assert!(
7837 shaders.contains("uchar4"),
7838 "matmul_vec_q4 should use uchar4 for packed byte loads"
7839 );
7840 assert!(
7842 shaders.contains("&0xF"),
7843 "matmul_vec_q4 should extract low nibble with &0xF"
7844 );
7845 assert!(
7846 shaders.contains(">>4"),
7847 "matmul_vec_q4 should extract high nibble with >>4"
7848 );
7849 assert!(
7851 shaders.contains("-8)"),
7852 "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7853 );
7854 assert!(
7856 shaders.contains("blk * 18"),
7857 "matmul_vec_q4 should use 18-byte block stride"
7858 );
7859 }
7860
7861 #[test]
7862 fn q4_model_has_matmul_q4_pipeline() {
7863 let config = minimal_q4_config();
7864 let model_rs = generate_model_rs(&config).unwrap();
7865
7866 assert!(
7867 model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7868 "MetalModel should have matmul_q4_pipeline field"
7869 );
7870 assert!(
7871 model_rs.contains("matmul_q4_pipeline,"),
7872 "MetalModel Self should include matmul_q4_pipeline"
7873 );
7874 }
7875
7876 #[test]
7877 fn q4_model_uses_dispatch_matmul_q4() {
7878 let config = minimal_q4_config();
7879 let model_rs = generate_model_rs(&config).unwrap();
7880
7881 assert!(
7882 model_rs.contains("dispatch_matmul_q4"),
7883 "Q4_0 model should use dispatch_matmul_q4 for projections"
7884 );
7885 assert!(
7886 model_rs.contains("fn dispatch_matmul_q4"),
7887 "model.rs should define dispatch_matmul_q4 method"
7888 );
7889 }
7890
7891 #[test]
7892 fn q4_model_loads_raw_bytes_not_dequantized() {
7893 let config = minimal_q4_config();
7894 let model_rs = generate_model_rs(&config).unwrap();
7895
7896 assert!(
7898 !model_rs.contains("f16_to_f32"),
7899 "Q4_0 model should not dequantize weights to f32"
7900 );
7901
7902 assert!(
7904 model_rs.contains("total_raw as u64"),
7905 "Q4_0 model should load raw bytes into Metal buffer"
7906 );
7907 }
7908
7909 #[test]
7910 fn q4_model_norms_stay_f32() {
7911 let config = minimal_q4_config();
7912 let model_rs = generate_model_rs(&config).unwrap();
7913
7914 assert!(
7915 model_rs.contains("let attn_norm = next_f32_buffer"),
7916 "attn_norm should use f32 buffer even for Q4_0 models"
7917 );
7918 assert!(
7919 model_rs.contains("let ffn_norm = next_f32_buffer"),
7920 "ffn_norm should use f32 buffer even for Q4_0 models"
7921 );
7922 assert!(
7923 model_rs.contains("let norm_buf = next_f32_buffer"),
7924 "final norm should use f32 buffer even for Q4_0 models"
7925 );
7926 }
7927
7928 #[test]
7929 fn q4_model_uses_fused_weight_loading() {
7930 let config = minimal_q4_config();
7931 let model_rs = generate_model_rs(&config).unwrap();
7932
7933 assert!(
7934 model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7935 "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7936 );
7937 assert!(
7938 model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7939 "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7940 );
7941 assert!(
7942 model_rs.contains("let o_weight = next_q4_buffer"),
7943 "Q4_0 model should use next_q4_buffer for O weight"
7944 );
7945 assert!(
7946 model_rs.contains("let down_weight = next_q4_buffer"),
7947 "Q4_0 model should use next_q4_buffer for down weight"
7948 );
7949 }
7950
7951 #[test]
7952 fn attention_flash_batch_kernel_exists() {
7953 let config = minimal_config();
7957 let model_rs = generate_model_rs(&config).unwrap();
7958 let shaders = generate_metal_shaders(&config);
7959
7960 assert!(
7961 shaders.contains("kernel void attention_flash_batch"),
7962 "shaders.metal must still contain the attention_flash_batch kernel"
7963 );
7964 assert!(
7965 shaders.contains("FLASH_K_TILE"),
7966 "flash kernel must tile K/V with a FLASH_K_TILE constant"
7967 );
7968 assert!(
7969 model_rs.contains("attention_flash_batch_pipeline"),
7970 "MetalModel must register the flash attention pipeline"
7971 );
7972 }
7973
7974 #[test]
7975 fn decode_uses_fused_rope_and_kv_copy() {
7976 let config = minimal_config();
7981 let model_rs = generate_model_rs(&config).unwrap();
7982
7983 let forward_start = model_rs.find("pub fn forward(").expect("forward missing");
7988 let forward_end = model_rs[forward_start..]
7989 .find("pub fn forward_prefill(")
7990 .expect("forward_prefill missing");
7991 let forward_body = &model_rs[forward_start..forward_start + forward_end];
7992
7993 assert!(
7994 forward_body.contains("dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1"),
7995 "single-token forward must use fused rope_qk_batch with M=1"
7996 );
7997 assert!(
7998 forward_body.contains("dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,"),
7999 "single-token forward must use fused copy_kv_both_batch"
8000 );
8001 assert!(
8002 !forward_body.contains("dispatch_copy_from_offset_f16"),
8003 "decode path should no longer call the per-K/per-V copy"
8004 );
8005 }
8006
/// Every attention kernel must read the KV cache as `half`, and the
/// single-token decode path must convert f32 K/V into that f16 cache.
#[test]
fn kv_cache_stored_as_f16() {
    let config = minimal_config();
    let shaders = generate_metal_shaders(&config);
    let model_rs = generate_model_rs(&config).unwrap();

    for kernel in [
        "kernel void attention(",
        "kernel void attention_batch(",
        "kernel void attention_flash_batch(",
        "kernel void attention_mma_flash_batch(",
    ] {
        let start = shaders
            .find(kernel)
            .unwrap_or_else(|| panic!("kernel {kernel} missing"));
        let rest = &shaders[start..];
        // Take the NEAREST of `){` / `) {`. Preferring `){` whenever it exists
        // anywhere would let the slice run into later kernels if this one
        // closes its signature with `) {`.
        let sig_end = [rest.find("){"), rest.find(") {")]
            .into_iter()
            .flatten()
            .min()
            .unwrap_or_else(|| panic!("signature of {kernel} never closes with `){{` or `) {{`"));
        let sig = &rest[..sig_end];
        assert!(
            sig.contains("device const half*")
                && sig.contains("k_cache")
                && sig.contains("v_cache"),
            "{kernel} must read k_cache/v_cache as `device const half*`"
        );
        // Check both caches; previously k_cache was tested twice and v_cache never.
        assert!(
            !sig.contains("device const float* k_cache")
                && !sig.contains("device const float* v_cache"),
            "{kernel} still reads k_cache/v_cache as float"
        );
    }

    assert!(
        shaders.contains("kernel void copy_f32_to_f16_offset"),
        "f32->f16 copy kernel must be present for single-token decode KV writes"
    );
    assert!(
        model_rs.contains("dispatch_copy_from_offset_f16"),
        "single-token decode must dispatch the f32->f16 KV copy"
    );
}
8055
/// The single-token attention kernel should use vector loads/stores and walk
/// `head_dim` four lanes at a time.
#[test]
fn decode_attention_uses_half4_vectorized_loads() {
    let shaders = generate_metal_shaders(&minimal_config());

    let kernel_at = shaders
        .find("kernel void attention(")
        .expect("decode attention kernel missing");
    // Start one byte in so the next-kernel search cannot match this kernel itself.
    let search_from = kernel_at + 1;
    let next_kernel_at = search_from
        + shaders[search_from..]
            .find("kernel void ")
            .expect("next kernel missing");
    let kernel_src = &shaders[kernel_at..next_kernel_at];

    for (needle, why) in [
        ("device const half4*", "decode attention must half4-load K/V"),
        ("device const float4*", "decode attention must float4-load Q"),
        ("device float4*", "decode attention must float4-store output"),
        ("head_dim4", "decode attention must iterate head_dim in chunks of 4"),
    ] {
        assert!(kernel_src.contains(needle), "{why}");
    }
}
8090
/// The MMA flash-attention kernel must exist in the shader source and be wired
/// into the model as a default-on path that can be opted out of.
#[test]
fn attention_mma_flash_batch_kernel_wired() {
    let config = minimal_config();
    let model_src = generate_model_rs(&config).unwrap();
    let shader_src = generate_metal_shaders(&config);

    let shader_expectations = [
        (
            "kernel void attention_mma_flash_batch",
            "shaders.metal must contain the MMA flash kernel",
        ),
        (
            "FLASH_MMA_Q_BLOCK",
            "MMA flash kernel must define Q_BLOCK tiling constant",
        ),
        (
            "simdgroup_multiply_accumulate",
            "MMA flash kernel must use hardware MMA",
        ),
    ];
    for (needle, why) in shader_expectations {
        assert!(shader_src.contains(needle), "{why}");
    }

    let model_expectations = [
        (
            "attention_mma_flash_batch_pipeline",
            "MetalModel must register the MMA flash pipeline",
        ),
        (
            "mma_opt_out",
            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out",
        ),
        (
            "!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8",
            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8",
        ),
    ];
    for (needle, why) in model_expectations {
        assert!(model_src.contains(needle), "{why}");
    }
}
8124
/// Long prompts must be processed in `MAX_BATCH_SIZE` chunks, not truncated.
#[test]
fn forward_prefill_batch_chunks_by_max_batch_size() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let chunks_prompt = src.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)");
    assert!(chunks_prompt, "forward_prefill_batch must chunk long prompts");

    let still_truncates = src.contains("tokens.len().min(MAX_BATCH_SIZE)");
    assert!(!still_truncates, "the old truncation path must be gone");
}
8142
/// With `qkv_bias` enabled, the bias must flow through the weight loader, the
/// pipeline set, both forward paths, and the shader source.
#[test]
fn qwen2_qkv_bias_wired_through_metal_codegen() {
    let qwen = ModelConfig {
        architecture: Architecture::Qwen2,
        qkv_bias: true,
        ..minimal_config()
    };
    let model_rs = generate_model_rs(&qwen).unwrap();

    let required = [
        ("qkv_bias: Buffer", "Qwen2 LayerBuffers must declare qkv_bias field"),
        (
            "let qkv_bias = next_f32_buffer",
            "Qwen2 layer init must load the bias from the weight blob",
        ),
        (
            "add_bias_batch_pipeline",
            "Qwen2 model struct must include the add_bias_batch_pipeline",
        ),
        (
            "fn dispatch_add_bias_batch",
            "Qwen2 codegen must emit dispatch_add_bias_batch helper",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.batch_qkv_buf",
            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.qkv_buf",
            "forward must call dispatch_add_bias_batch on the single-token qkv_buf",
        ),
    ];
    for (snippet, msg) in required {
        assert!(model_rs.contains(snippet), "{msg}");
    }

    let shaders = generate_metal_shaders(&qwen);
    assert!(
        shaders.contains("kernel void add_bias_batch"),
        "shaders.metal must contain the add_bias_batch kernel"
    );
}
8187
/// A bias-free (Llama-style) config must not pull in any QKV-bias codegen.
#[test]
fn llama_does_not_emit_qkv_bias_machinery() {
    let config = minimal_config();
    assert!(!config.qkv_bias);
    let src = generate_model_rs(&config).unwrap();

    for (forbidden, msg) in [
        ("qkv_bias: Buffer", "Llama must not have qkv_bias field"),
        (
            "add_bias_batch_pipeline",
            "Llama must not pull in add_bias_batch_pipeline",
        ),
        (
            "dispatch_add_bias_batch",
            "Llama must not call dispatch_add_bias_batch",
        ),
    ] {
        assert!(!src.contains(forbidden), "{msg}");
    }
}
8208
/// The Q4 matmul helper should accept a borrowed compute encoder.
#[test]
fn q4_dispatch_helper_takes_compute_encoder_ref() {
    let generated = generate_model_rs(&minimal_q4_config()).unwrap();
    let signature = "fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef";
    assert!(
        generated.contains(signature),
        "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
    );
}
8219
/// An f32 model's single-token forward must not route through the Q4 matmul.
#[test]
fn f32_model_does_not_use_q4_dispatch() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();

    let tail = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .map(|at| &model_rs[at..])
        .unwrap();
    // Stop at the first dispatch helper; without one, inspect everything after forward().
    let forward_code = match tail.find("\n fn dispatch_") {
        Some(end) => &tail[..end],
        None => tail,
    };

    assert!(
        !forward_code.contains("dispatch_matmul_q4"),
        "f32 model forward should not use dispatch_matmul_q4"
    );
}
8239
/// A Q4_0 model should load its LM head through the Q4 buffer loader.
#[test]
fn q4_model_lm_head_uses_q4_buffer() {
    let q4_config = minimal_q4_config();
    let generated = generate_model_rs(&q4_config).unwrap();
    let loads_lm_head_as_q4 = generated.contains("let lm_head_buf = next_q4_buffer");
    assert!(
        loads_lm_head_as_q4,
        "Q4_0 model should use next_q4_buffer for lm_head"
    );
}
8250
/// The shader's `vec_tile` scratch array must scale with the model's widest
/// dimension rather than a fixed size.
#[test]
fn vec_tile_size_matches_model_dimensions() {
    let small_src = generate_metal_shaders(&minimal_config());
    assert!(
        small_src.contains("vec_tile[128]"),
        "vec_tile should be sized to max(hidden, intermediate) = 128"
    );

    let large_config = ModelConfig {
        hidden_size: 2048,
        intermediate_size: 8192,
        ..minimal_config()
    };
    let large_src = generate_metal_shaders(&large_config);
    assert!(
        large_src.contains("vec_tile[8192]"),
        "vec_tile should be 8192 for models with intermediate=8192"
    );
    assert!(
        !large_src.contains("vec_tile[4096]"),
        "vec_tile should NOT be hardcoded to 4096"
    );
}
8275
/// The generated Cargo.toml must declare the HTTP-server dependencies.
#[test]
fn generated_cargo_toml_has_server_deps() {
    let toml = generate_cargo_toml("my-model");
    // Match whole dependency keys (`name = ...`). A plain substring check for
    // "serde" is always satisfied by `serde_json`, so it could never detect a
    // missing `serde` entry.
    let has_dep = |name: &str| {
        toml.lines().any(|line| match line.trim_start().strip_prefix(name) {
            Some(rest) => rest.trim_start().starts_with('='),
            None => false,
        })
    };

    assert!(has_dep("tiny_http"), "Cargo.toml should depend on tiny_http");
    assert!(has_dep("serde"), "Cargo.toml should depend on serde");
    assert!(has_dep("serde_json"), "Cargo.toml should depend on serde_json");
}
8289
/// main.rs must expose a `--serve` mode backed by tiny_http.
#[test]
fn generated_main_rs_has_serve_mode() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();

    let checks: [(&str, &str); 4] = [
        ("--serve", "main.rs should parse --serve flag"),
        ("--port", "main.rs should parse --port flag"),
        ("fn serve(", "main.rs should define serve function"),
        ("tiny_http::Server::http", "main.rs should create tiny_http server"),
    ];
    for (needle, msg) in checks {
        assert!(main_rs.contains(needle), "{msg}");
    }
}
8312
/// The server must route the OpenAI-compatible endpoints plus a health check.
#[test]
fn generated_main_rs_has_chat_completions_endpoint() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();

    for route in ["/v1/chat/completions", "/v1/models", "/health"] {
        assert!(
            main_rs.contains(route),
            "main.rs should handle {route} endpoint"
        );
    }
}
8331
/// Streaming responses must use SSE framing and end with the `[DONE]` sentinel.
#[test]
fn generated_main_rs_has_sse_streaming() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();

    [
        ("text/event-stream", "main.rs should set SSE content type for streaming"),
        ("chat.completion.chunk", "main.rs should emit SSE chunks"),
        ("[DONE]", "main.rs should emit [DONE] sentinel"),
    ]
    .into_iter()
    .for_each(|(needle, msg)| assert!(main_rs.contains(needle), "{msg}"));
}
8350
/// Chat messages must be rendered with ChatML delimiters.
#[test]
fn generated_main_rs_has_chat_message_formatting() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();

    assert!(
        main_rs.contains("fn format_chat_messages"),
        "main.rs should define format_chat_messages function"
    );
    for marker in ["<|im_start|>", "<|im_end|>"] {
        assert!(main_rs.contains(marker), "main.rs should use ChatML format");
    }
}
8369
/// The request payload types must exist and be deserializable.
#[test]
fn generated_main_rs_has_request_types() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();

    for name in ["ChatRequest", "ChatMessage"] {
        let decl = format!("struct {name}");
        assert!(
            main_rs.contains(&decl),
            "main.rs should define {name} struct"
        );
    }
    assert!(
        main_rs.contains("Deserialize"),
        "main.rs should derive Deserialize for request types"
    );
}
8388
/// The model needs a `reset()` so one process can serve multiple requests.
#[test]
fn generated_model_has_reset_method() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let has_reset = src.contains("pub fn reset(&mut self)");
    let rewinds_pos = src.contains("self.pos = 0");

    assert!(
        has_reset,
        "model.rs should have a reset() method for multi-request serving"
    );
    assert!(rewinds_pos, "reset() should reset position to 0");
}
8403
/// Adding server mode must not remove the interactive CLI path.
#[test]
fn generated_main_rs_cli_mode_still_works() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();

    assert!(
        main_rs.contains("fn cli_mode("),
        "main.rs should define cli_mode function"
    );
    for call in ["model.forward", "model.forward_prefill"] {
        assert!(main_rs.contains(call), "main.rs should call {call}");
    }
}
8423
/// Every batched prefill kernel must be present in the generated shader source.
#[test]
fn generated_shaders_contain_batch_kernels() {
    let shaders = generate_metal_shaders(&minimal_config());

    // (kernel name, extra note appended to the failure message)
    let kernels = [
        ("matmul_vec_batch", ""),
        ("matmul_vec_q8_batch", ""),
        ("matmul_q8_gemm_batch", " (weight-reuse GEMM)"),
        ("matmul_vec_q4_batch", ""),
        ("rms_norm_batch", ""),
        ("silu_mul_fused_batch", ""),
        ("add_inplace_batch", ""),
        ("copy_embedding_batch", ""),
    ];
    for (name, note) in kernels {
        let decl = format!("kernel void {name}");
        assert!(
            shaders.contains(&decl),
            "shaders should contain {name} kernel{note}"
        );
    }
}
8463
/// `MetalModel` must hold a compute pipeline for each batched kernel.
#[test]
fn generated_model_has_batch_pipelines() {
    const PIPELINES: [&str; 11] = [
        "matmul_batch_pipeline",
        "matmul_q8_batch_pipeline",
        "matmul_q8_gemm_batch_pipeline",
        "matmul_q4_batch_pipeline",
        "rms_norm_batch_pipeline",
        "rope_batch_pipeline",
        "silu_mul_fused_batch_pipeline",
        "add_inplace_batch_pipeline",
        "copy_embedding_batch_pipeline",
        "attention_batch_pipeline",
        "copy_kv_batch_pipeline",
    ];

    let src = generate_model_rs(&minimal_config()).unwrap();
    PIPELINES.iter().for_each(|pipeline| {
        let field = format!("{pipeline}: ComputePipelineState");
        assert!(src.contains(&field), "MetalModel should have {pipeline} field");
    });
}
8488
/// `MetalModel` must own a scratch buffer for every stage of batched prefill.
#[test]
fn generated_model_has_batch_buffers() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let buffers = [
        "batch_hidden_buf",
        "batch_residual_buf",
        "batch_qkv_buf",
        "batch_attn_out_buf",
        "batch_attn_proj_buf",
        "batch_gate_up_buf",
        "batch_ffn_hidden_buf",
        "batch_ffn_out_buf",
        "batch_tokens_buf",
        "batch_positions_buf",
    ];
    for buf in &buffers {
        let declared = src.contains(&format!("{buf}: Buffer"));
        assert!(declared, "MetalModel should have {buf} field");
    }
}
8512
/// Batched prefill must exist, and single-token prefill must delegate to it.
#[test]
fn generated_model_has_forward_prefill_batch() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let has_batch_fn = src.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])");
    assert!(has_batch_fn, "MetalModel should have forward_prefill_batch method");

    let delegates = src.contains("self.forward_prefill_batch(&[token_id])");
    assert!(
        delegates,
        "forward_prefill should delegate to forward_prefill_batch"
    );
}
8529
/// The batch-size cap must be emitted as a public constant.
#[test]
fn generated_model_has_max_batch_size_constant() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    let expected_decl = "pub const MAX_BATCH_SIZE: usize = 512";
    assert!(
        generated.contains(expected_decl),
        "model.rs should define MAX_BATCH_SIZE constant"
    );
}
8540
/// `forward_prefill_batch` must use the batched variant of every per-layer op.
#[test]
fn forward_prefill_batch_uses_batch_dispatch() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();

    let from = model_rs
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let rest = &model_rs[from..];
    // Everything up to reset(), or the remainder of the file if reset() is absent.
    let batch_code = rest
        .split_once("\n pub fn reset")
        .map_or(rest, |(head, _)| head);

    for helper in [
        "dispatch_rms_norm_batch",
        "dispatch_copy_embedding_batch",
        "dispatch_silu_mul_fused_batch",
        "dispatch_attention_batch",
        "dispatch_copy_kv_both_batch",
        "dispatch_rope_qk_batch",
    ] {
        assert!(
            batch_code.contains(helper),
            "forward_prefill_batch should use {helper}"
        );
    }
}
8584
/// A Q8 model's batched prefill must use the Q8 batched matmul.
#[test]
fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
    let src = generate_model_rs(&minimal_q8_config()).unwrap();

    let begin = src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let after_begin = &src[begin..];
    let batch_code = match after_begin.find("\n pub fn reset") {
        Some(stop) => &after_begin[..stop],
        None => after_begin,
    };

    assert!(
        batch_code.contains("dispatch_matmul_q8_batch"),
        "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
    );
}
8604
/// A Q4 model's batched prefill must use the Q4 batched matmul.
#[test]
fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
    let generated = generate_model_rs(&minimal_q4_config()).unwrap();

    let start_idx = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &generated[start_idx..];
    let stop_idx = tail.find("\n pub fn reset").unwrap_or(tail.len());
    let prefill_src = &tail[..stop_idx];

    let uses_q4 = prefill_src.contains("dispatch_matmul_q4_batch");
    assert!(
        uses_q4,
        "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
    );
}
8624
/// main.rs must feed prompt tokens through batched prefill.
#[test]
fn generated_main_rs_uses_batched_prefill() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();
    let uses_batch = main_src.contains("forward_prefill_batch");
    assert!(
        uses_batch,
        "main.rs should use forward_prefill_batch for prompt tokens"
    );
}
8635
/// An f32 model's batched prefill must use only the f32 batched matmul.
#[test]
fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();

    let opening = model_rs
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let remainder = &model_rs[opening..];
    let batch_code = remainder
        .split("\n pub fn reset")
        .next()
        .unwrap_or(remainder);

    assert!(
        batch_code.contains("dispatch_matmul_batch"),
        "f32 forward_prefill_batch should use dispatch_matmul_batch"
    );
    for quantized in ["dispatch_matmul_q8_batch", "dispatch_matmul_q4_batch"] {
        assert!(
            !batch_code.contains(quantized),
            "f32 forward_prefill_batch should not use {quantized}"
        );
    }
}
8664
/// Single-token `forward()` should copy the embedding row on the CPU through
/// unified memory instead of dispatching a GPU copy.
#[test]
fn forward_uses_cpu_embedding_lookup() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let forward_body = &model_rs[forward_start..];
    // End at the NEAREST following method. An `or_else` chain picks the first
    // marker that exists anywhere, which may lie well past the end of forward()
    // and drag unrelated methods into the inspected text.
    let forward_end = ["\n pub fn forward_profile", "\n pub fn forward_prefill"]
        .into_iter()
        .filter_map(|marker| forward_body.find(marker))
        .min()
        .unwrap_or(forward_body.len());
    let forward_code = &forward_body[..forward_end];

    assert!(
        forward_code.contains("embed_buf.contents()"),
        "forward() should access embed_buf via CPU unified memory for embedding lookup"
    );
    assert!(
        forward_code.contains("copy_nonoverlapping"),
        "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
    );
    assert!(
        !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
        "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
    );
}
8696
/// `forward_profile()` must exist and report per-phase timings.
#[test]
fn forward_profile_method_exists() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let expectations = [
        (
            "pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>",
            "MetalModel should have forward_profile() method",
        ),
        (
            "[profile]",
            "forward_profile() should print timing with [profile] prefix",
        ),
        ("d_embed", "forward_profile() should measure embedding time"),
        ("d_layers", "forward_profile() should measure layer time"),
        ("d_logits", "forward_profile() should measure logits time"),
    ];
    for (needle, msg) in expectations {
        assert!(src.contains(needle), "{msg}");
    }
}
8724
/// The CLI must expose `--profile` and route it to `forward_profile`.
#[test]
fn generated_cli_has_profile_flag() {
    let cli_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    let parses_flag = cli_src.contains("--profile");
    let calls_profiler = cli_src.contains("forward_profile");

    assert!(parses_flag, "CLI should support --profile flag");
    assert!(
        calls_profiler,
        "CLI should call forward_profile when --profile is set"
    );
}
8739
/// The CLI generation loop should yield the thread between tokens.
#[test]
fn generated_cli_has_thermal_yield() {
    let cli_src = generate_main_rs("test-model", &minimal_config()).unwrap();
    let yields = cli_src.contains("yield_now()");
    assert!(
        yields,
        "CLI generation loop should include thread::yield_now() for thermal management"
    );
}
8750
/// `forward()` must accept `pos == 0`, i.e. decoding with no prior prefill.
#[test]
fn generated_forward_handles_single_token_prompt() {
    /// Longest prefix of `s` that is at most `max` bytes and ends on a char
    /// boundary. Unlike `&s[..max]`, this does not panic on short input or in
    /// the middle of a multi-byte character.
    fn head(s: &str, max: usize) -> &str {
        let mut end = max.min(s.len());
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("forward() must exist");
    // Only the opening of forward() is inspected; a guard rejecting pos == 0
    // would be expected near the top of the function.
    let forward_body = head(&model_rs[forward_start..], 400);

    assert!(
        !forward_body.contains("assert!(self.pos > 0"),
        "forward() must accept pos=0 (first token with no prefill)"
    );

    assert!(
        model_rs.contains("self.pos"),
        "forward() should use self.pos to track sequence position"
    );
}
8779
/// `reset()` must rewind the sequence position and drop any pending command buffer.
#[test]
fn generated_reset_clears_kv_cache_position() {
    /// Longest prefix of `s` that is at most `max` bytes and ends on a char
    /// boundary, so a fixed-size window cannot panic near the end of the file.
    fn head(s: &str, max: usize) -> &str {
        let mut end = max.min(s.len());
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let reset_start = model_rs
        .find("pub fn reset(&mut self)")
        .expect("reset() must exist");
    // Clamped window. An unclamped `reset_start + 200` panics if reset() sits
    // within 200 bytes of the end of the generated file.
    let reset_body = head(&model_rs[reset_start..], 200);

    assert!(
        reset_body.contains("self.pos = 0"),
        "reset() must set self.pos = 0"
    );

    assert!(
        reset_body.contains("self.prev_cmd = None"),
        "reset() should clear prev_cmd for clean command buffer state"
    );
}
8804
/// Prompt formatting must iterate the message slice and always open an
/// assistant turn; the server must reset the model between requests.
#[test]
fn generated_serve_handles_empty_messages_gracefully() {
    /// Longest prefix of `s` that is at most `max` bytes and ends on a char
    /// boundary. A plain length clamp can still split a multi-byte character.
    fn head(s: &str, max: usize) -> &str {
        let mut end = max.min(s.len());
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    let config = minimal_config();
    let main_rs = generate_main_rs("test-model", &config).unwrap();

    let format_fn_start = main_rs
        .find("fn format_chat_messages")
        .expect("format_chat_messages must exist");
    let format_fn_body = head(&main_rs[format_fn_start..], 500);

    assert!(
        format_fn_body.contains("for msg in messages"),
        "format_chat_messages should iterate over the messages slice"
    );
    assert!(
        format_fn_body.contains("<|im_start|>assistant"),
        "format_chat_messages should always append assistant prompt header"
    );

    let serve_fn_start = main_rs
        .find("fn serve(")
        .expect("serve function must exist");
    let serve_fn_body = &main_rs[serve_fn_start..];
    assert!(
        serve_fn_body.contains("model.reset()"),
        "serve function should reset model between requests"
    );
}
8840
/// `forward()` must advance `self.pos` after processing each token.
#[test]
fn generated_model_forward_increments_pos() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let forward_body = &model_rs[forward_start..];
    // End at the NEAREST following method. An `or_else` chain picks the first
    // marker present anywhere, so the slice could include later methods and
    // the increment could be found outside forward() itself.
    let forward_end = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .into_iter()
    .filter_map(|marker| forward_body.find(marker))
    .min()
    .unwrap_or(forward_body.len());
    let forward_code = &forward_body[..forward_end];

    assert!(
        forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
        "forward() must increment self.pos after processing a token"
    );
}
8864}